HsiangKai created this revision.
HsiangKai added reviewers: craig.topper, frasercrmck, rogfer01, kito-cheng.
Herald added subscribers: StephenFan, vkmr, dexonsmith, evandro, luismarques, 
apazos, sameer.abuasal, s.egerton, Jim, benna, psnobl, jocewei, PkmX, the_o, 
brucehoult, MartinMosbeck, edward-jones, zzheng, jrtc27, shiva0217, niosHD, 
sabuasal, simoncook, johnrusso, rbar, asb, hiraditya.
HsiangKai requested review of this revision.
Herald added subscribers: llvm-commits, cfe-commits, MaskRay.
Herald added projects: clang, LLVM.

This is a proof-of-concept patch. It does not add the tail policy
argument to all the builtins/intrinsics. This patch uses vadd as an
example to add the tail policy argument.

I added several new classes. These classes do not need to stay in the
target description long term; I introduced them only to limit the scope
of the modification to vadd.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D105092

Files:
  clang/include/clang/Basic/riscv_vector.td
  clang/test/CodeGen/RISCV/rvv-intrinsics/vadd-policy.c
  clang/utils/TableGen/RISCVVEmitter.cpp
  llvm/include/llvm/IR/IntrinsicsRISCV.td
  llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
  llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
  llvm/lib/Target/RISCV/RISCVInstrFormats.td
  llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
  llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
  llvm/test/CodeGen/RISCV/rvv/vadd-policy.ll

Index: llvm/test/CodeGen/RISCV/rvv/vadd-policy.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/RISCV/rvv/vadd-policy.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+experimental-v -verify-machineinstrs \
+; RUN:   --riscv-no-aliases < %s | FileCheck %s
+
+declare <vscale x 8 x i8> @llvm.riscv.vadd.nxv8i8.nxv8i8(
+  <vscale x 8 x i8>,
+  <vscale x 8 x i8>,
+  i64);
+
+define <vscale x 8 x i8> @intrinsic_vadd_vv_nxv8i8_nxv8i8_nxv8i8(<vscale x 8 x i8> %0, <vscale x 8 x i8> %1, i64 %2) nounwind {
+; CHECK-LABEL: intrinsic_vadd_vv_nxv8i8_nxv8i8_nxv8i8:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli zero, a0, e8, m1, ta, mu
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    jalr zero, 0(ra)
+entry:
+  %a = call <vscale x 8 x i8> @llvm.riscv.vadd.nxv8i8.nxv8i8(
+    <vscale x 8 x i8> %0,
+    <vscale x 8 x i8> %1,
+    i64 %2)
+
+  ret <vscale x 8 x i8> %a
+}
+
+declare <vscale x 8 x i8> @llvm.riscv.vadd.mask.nxv8i8.nxv8i8(
+  <vscale x 8 x i8>,
+  <vscale x 8 x i8>,
+  <vscale x 8 x i8>,
+  <vscale x 8 x i1>,
+  i64, i64);
+
+define <vscale x 8 x i8> @intrinsic_vadd_mask_tu(<vscale x 8 x i8> %0, <vscale x 8 x i8> %1, <vscale x 8 x i8> %2, <vscale x 8 x i1> %3, i64 %4) nounwind {
+; CHECK-LABEL: intrinsic_vadd_mask_tu:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 3, e8, m1, tu, mu
+; CHECK-NEXT:    vadd.vv v8, v9, v10, v0.t
+; CHECK-NEXT:    jalr zero, 0(ra)
+entry:
+  %a = call <vscale x 8 x i8> @llvm.riscv.vadd.mask.nxv8i8.nxv8i8(
+    <vscale x 8 x i8> %0,
+    <vscale x 8 x i8> %1,
+    <vscale x 8 x i8> %2,
+    <vscale x 8 x i1> %3,
+    i64 %4, i64 0)
+
+  ret <vscale x 8 x i8> %a
+}
+
+define <vscale x 8 x i8> @intrinsic_vadd_mask_ta(<vscale x 8 x i8> %0, <vscale x 8 x i8> %1, <vscale x 8 x i8> %2, <vscale x 8 x i1> %3, i64 %4) nounwind {
+; CHECK-LABEL: intrinsic_vadd_mask_ta:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 3, e8, m1, ta, mu
+; CHECK-NEXT:    vadd.vv v8, v9, v10, v0.t
+; CHECK-NEXT:    jalr zero, 0(ra)
+entry:
+  %a = call <vscale x 8 x i8> @llvm.riscv.vadd.mask.nxv8i8.nxv8i8(
+    <vscale x 8 x i8> %0,
+    <vscale x 8 x i8> %1,
+    <vscale x 8 x i8> %2,
+    <vscale x 8 x i1> %3,
+    i64 %4, i64 1)
+
+  ret <vscale x 8 x i8> %a
+}
+
Index: llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
===================================================================
--- llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -269,6 +269,36 @@
                      VMV0:$vm, GPR:$vl, sew)>;
 }
 
+multiclass VPatBinaryVL_VV_WithPolicy<SDNode vop,
+                                      string instruction_name,
+                                      ValueType result_type,
+                                      ValueType op_type,
+                                      ValueType mask_type,
+                                      int sew,
+                                      LMULInfo vlmul,
+                                      VReg RetClass,
+                                      VReg op_reg_class> {
+  def : Pat<(result_type (vop
+                         (op_type op_reg_class:$rs1),
+                         (op_type op_reg_class:$rs2),
+                         (mask_type true_mask),
+                         VLOpFrag)),
+            (!cast<Instruction>(instruction_name#"_VV_"# vlmul.MX)
+                         op_reg_class:$rs1,
+                         op_reg_class:$rs2,
+                         GPR:$vl, sew)>;
+  def : Pat<(result_type (vop
+                         (op_type op_reg_class:$rs1),
+                         (op_type op_reg_class:$rs2),
+                         (mask_type VMV0:$vm),
+                         VLOpFrag)),
+        (!cast<Instruction>(instruction_name#"_VV_"# vlmul.MX#"_MASK")
+                     (result_type (IMPLICIT_DEF)),
+                     op_reg_class:$rs1,
+                     op_reg_class:$rs2,
+                     VMV0:$vm, GPR:$vl, sew, 0)>;
+}
+
 multiclass VPatBinaryVL_XI<SDNode vop,
                            string instruction_name,
                            string suffix,
@@ -302,6 +332,39 @@
                      VMV0:$vm, GPR:$vl, sew)>;
 }
 
+multiclass VPatBinaryVL_XI_WithPolicy<SDNode vop,
+                                      string instruction_name,
+                                      string suffix,
+                                      ValueType result_type,
+                                      ValueType vop_type,
+                                      ValueType mask_type,
+                                      int sew,
+                                      LMULInfo vlmul,
+                                      VReg RetClass,
+                                      VReg vop_reg_class,
+                                      ComplexPattern SplatPatKind,
+                                      DAGOperand xop_kind> {
+  def : Pat<(result_type (vop
+                     (vop_type vop_reg_class:$rs1),
+                     (vop_type (SplatPatKind (XLenVT xop_kind:$rs2))),
+                     (mask_type true_mask),
+                     VLOpFrag)),
+        (!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX)
+                     vop_reg_class:$rs1,
+                     xop_kind:$rs2,
+                     GPR:$vl, sew)>;
+  def : Pat<(result_type (vop
+                     (vop_type vop_reg_class:$rs1),
+                     (vop_type (SplatPatKind (XLenVT xop_kind:$rs2))),
+                     (mask_type VMV0:$vm),
+                     VLOpFrag)),
+        (!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX#"_MASK")
+                     (result_type (IMPLICIT_DEF)),
+                     vop_reg_class:$rs1,
+                     xop_kind:$rs2,
+                     VMV0:$vm, GPR:$vl, sew, 0)>;
+}
+
 multiclass VPatBinaryVL_VV_VX<SDNode vop, string instruction_name> {
   foreach vti = AllIntegerVectors in {
     defm : VPatBinaryVL_VV<vop, instruction_name,
@@ -314,6 +377,18 @@
   }
 }
 
+multiclass VPatBinaryVL_VV_VX_WithPolicy<SDNode vop, string instruction_name> {
+  foreach vti = AllIntegerVectors in {
+    defm : VPatBinaryVL_VV_WithPolicy<vop, instruction_name,
+                           vti.Vector, vti.Vector, vti.Mask, vti.Log2SEW,
+                           vti.LMul, vti.RegClass, vti.RegClass>;
+    defm : VPatBinaryVL_XI_WithPolicy<vop, instruction_name, "VX",
+                           vti.Vector, vti.Vector, vti.Mask, vti.Log2SEW,
+                           vti.LMul, vti.RegClass, vti.RegClass,
+                           SplatPat, GPR>;
+  }
+}
+
 multiclass VPatBinaryVL_VV_VX_VI<SDNode vop, string instruction_name,
                                  Operand ImmType = simm5>
     : VPatBinaryVL_VV_VX<vop, instruction_name> {
@@ -326,6 +401,18 @@
   }
 }
 
+multiclass VPatBinaryVL_VV_VX_VI_WithPolicy<SDNode vop, string instruction_name,
+                                 Operand ImmType = simm5>
+    : VPatBinaryVL_VV_VX_WithPolicy<vop, instruction_name> {
+  foreach vti = AllIntegerVectors in {
+    defm : VPatBinaryVL_XI_WithPolicy<vop, instruction_name, "VI",
+                           vti.Vector, vti.Vector, vti.Mask, vti.Log2SEW,
+                           vti.LMul, vti.RegClass, vti.RegClass,
+                           !cast<ComplexPattern>(SplatPat#_#ImmType),
+                           ImmType>;
+  }
+}
+
 class VPatBinaryVL_VF<SDNode vop,
                       string instruction_name,
                       ValueType result_type,
@@ -589,7 +676,7 @@
 }
 
 // 12.1. Vector Single-Width Integer Add and Subtract
-defm : VPatBinaryVL_VV_VX_VI<riscv_add_vl, "PseudoVADD">;
+defm : VPatBinaryVL_VV_VX_VI_WithPolicy<riscv_add_vl, "PseudoVADD">;
 defm : VPatBinaryVL_VV_VX<riscv_sub_vl, "PseudoVSUB">;
 // Handle VRSUB specially since it's the only integer binary op with reversed
 // pattern operands
Index: llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
===================================================================
--- llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -975,6 +975,26 @@
   let BaseInstr = !cast<Instruction>(PseudoToVInst<NAME>.VInst);
 }
 
+class VPseudoBinaryMaskWithPolicy<VReg RetClass,
+                                  RegisterClass Op1Class,
+                                  DAGOperand Op2Class,
+                                  string Constraint> :
+        Pseudo<(outs GetVRegNoV0<RetClass>.R:$rd),
+                (ins GetVRegNoV0<RetClass>.R:$merge,
+                     Op1Class:$rs2, Op2Class:$rs1,
+                     VMaskOp:$vm, AVL:$vl, ixlenimm:$sew, uimm5:$policy), []>,
+        RISCVVPseudo {
+  let mayLoad = 0;
+  let mayStore = 0;
+  let hasSideEffects = 0;
+  let Constraints = Join<[Constraint, "$rd = $merge"], ",">.ret;
+  let HasVLOp = 1;
+  let HasSEWOp = 1;
+  let HasMergeOp = 1;
+  let HasPolicy = true;
+  let BaseInstr = !cast<Instruction>(PseudoToVInst<NAME>.VInst);
+}
+
 // Like VPseudoBinaryMask, but output can be V0.
 class VPseudoBinaryMOutMask<VReg RetClass,
                             RegisterClass Op1Class,
@@ -1496,6 +1516,19 @@
   }
 }
 
+multiclass VPseudoBinaryWithPolicy<VReg RetClass,
+                                   VReg Op1Class,
+                                   DAGOperand Op2Class,
+                                   LMULInfo MInfo,
+                                   string Constraint = ""> {
+  let VLMul = MInfo.value in {
+    def "_" # MInfo.MX : VPseudoBinaryNoMask<RetClass, Op1Class, Op2Class,
+                                             Constraint>;
+    def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMaskWithPolicy<RetClass, Op1Class, Op2Class,
+                                                               Constraint>;
+  }
+}
+
 multiclass VPseudoBinaryM<VReg RetClass,
                           VReg Op1Class,
                           DAGOperand Op2Class,
@@ -1541,6 +1574,11 @@
     defm _VV : VPseudoBinary<m.vrclass, m.vrclass, m.vrclass, m, Constraint>;
 }
 
+multiclass VPseudoBinaryV_VV_WithPolicy<string Constraint = ""> {
+  foreach m = MxList.m in
+    defm _VV : VPseudoBinaryWithPolicy<m.vrclass, m.vrclass, m.vrclass, m, Constraint>;
+}
+
 multiclass VPseudoBinaryV_VV_EEW<int eew, string Constraint = ""> {
   foreach m = MxList.m in {
     foreach sew = EEWList in {
@@ -1561,6 +1599,11 @@
     defm "_VX" : VPseudoBinary<m.vrclass, m.vrclass, GPR, m, Constraint>;
 }
 
+multiclass VPseudoBinaryV_VX_WithPolicy<string Constraint = ""> {
+  foreach m = MxList.m in
+    defm "_VX" : VPseudoBinaryWithPolicy<m.vrclass, m.vrclass, GPR, m, Constraint>;
+}
+
 multiclass VPseudoBinaryV_VF<string Constraint = ""> {
   foreach m = MxList.m in
     foreach f = FPList.fpinfo in
@@ -1573,6 +1616,11 @@
     defm _VI : VPseudoBinary<m.vrclass, m.vrclass, ImmType, m, Constraint>;
 }
 
+multiclass VPseudoBinaryV_VI_WithPolicy<Operand ImmType = simm5, string Constraint = ""> {
+  foreach m = MxList.m in
+    defm _VI : VPseudoBinaryWithPolicy<m.vrclass, m.vrclass, ImmType, m, Constraint>;
+}
+
 multiclass VPseudoBinaryM_MM {
   foreach m = MxList.m in
     let VLMul = m.value in {
@@ -1801,6 +1849,12 @@
   defm "" : VPseudoBinaryV_VI<ImmType, Constraint>;
 }
 
+multiclass VPseudoBinaryV_VV_VX_VI_WithPolicy<Operand ImmType = simm5, string Constraint = ""> {
+  defm "" : VPseudoBinaryV_VV_WithPolicy<Constraint>;
+  defm "" : VPseudoBinaryV_VX_WithPolicy<Constraint>;
+  defm "" : VPseudoBinaryV_VI_WithPolicy<ImmType, Constraint>;
+}
+
 multiclass VPseudoBinaryV_VV_VX {
   defm "" : VPseudoBinaryV_VV;
   defm "" : VPseudoBinaryV_VX;
@@ -2309,6 +2363,28 @@
                    (op2_type op2_kind:$rs2),
                    (mask_type V0), GPR:$vl, sew)>;
 
+class VPatBinaryMaskWithPolicy<string intrinsic_name,
+                               string inst,
+                               ValueType result_type,
+                               ValueType op1_type,
+                               ValueType op2_type,
+                               ValueType mask_type,
+                               int sew,
+                               VReg result_reg_class,
+                               VReg op1_reg_class,
+                               DAGOperand op2_kind> :
+  Pat<(result_type (!cast<Intrinsic>(intrinsic_name#"_mask")
+                   (result_type result_reg_class:$merge),
+                   (op1_type op1_reg_class:$rs1),
+                   (op2_type op2_kind:$rs2),
+                   (mask_type V0),
+                   VLOpFrag, (XLenVT uimm5:$policy))),
+                   (!cast<Instruction>(inst#"_MASK")
+                   (result_type result_reg_class:$merge),
+                   (op1_type op1_reg_class:$rs1),
+                   (op2_type op2_kind:$rs2),
+                   (mask_type V0), GPR:$vl, sew, (XLenVT uimm5:$policy))>;
+
 // Same as above but source operands are swapped.
 class VPatBinaryMaskSwapped<string intrinsic_name,
                             string inst,
@@ -2565,6 +2641,24 @@
                        op2_kind>;
 }
 
+multiclass VPatBinaryWithPolicy<string intrinsic,
+                                string inst,
+                                ValueType result_type,
+                                ValueType op1_type,
+                                ValueType op2_type,
+                                ValueType mask_type,
+                                int sew,
+                                VReg result_reg_class,
+                                VReg op1_reg_class,
+                                DAGOperand op2_kind>
+{
+  def : VPatBinaryNoMask<intrinsic, inst, result_type, op1_type, op2_type,
+                         sew, op1_reg_class, op2_kind>;
+  def : VPatBinaryMaskWithPolicy<intrinsic, inst, result_type, op1_type, op2_type,
+                                 mask_type, sew, result_reg_class, op1_reg_class,
+                                 op2_kind>;
+}
+
 multiclass VPatBinarySwapped<string intrinsic,
                       string inst,
                       ValueType result_type,
@@ -2653,6 +2747,15 @@
                       vti.RegClass, vti.RegClass>;
 }
 
+multiclass VPatBinaryV_VV_WithPolicy<string intrinsic, string instruction,
+                          list<VTypeInfo> vtilist> {
+  foreach vti = vtilist in
+    defm : VPatBinaryWithPolicy<intrinsic, instruction # "_VV_" # vti.LMul.MX,
+                      vti.Vector, vti.Vector, vti.Vector,vti.Mask,
+                      vti.Log2SEW, vti.RegClass,
+                      vti.RegClass, vti.RegClass>;
+}
+
 multiclass VPatBinaryV_VV_INT<string intrinsic, string instruction,
                           list<VTypeInfo> vtilist> {
   foreach vti = vtilist in {
@@ -2694,6 +2797,17 @@
   }
 }
 
+multiclass VPatBinaryV_VX_WithPolicy<string intrinsic, string instruction,
+                          list<VTypeInfo> vtilist> {
+  foreach vti = vtilist in {
+    defvar kind = "V"#vti.ScalarSuffix;
+    defm : VPatBinaryWithPolicy<intrinsic, instruction#"_"#kind#"_"#vti.LMul.MX,
+                      vti.Vector, vti.Vector, vti.Scalar, vti.Mask,
+                      vti.Log2SEW, vti.RegClass,
+                      vti.RegClass, vti.ScalarRegClass>;
+  }
+}
+
 multiclass VPatBinaryV_VX_INT<string intrinsic, string instruction,
                           list<VTypeInfo> vtilist> {
   foreach vti = vtilist in
@@ -2712,6 +2826,15 @@
                       vti.RegClass, imm_type>;
 }
 
+multiclass VPatBinaryV_VI_WithPolicy<string intrinsic, string instruction,
+                                     list<VTypeInfo> vtilist, Operand imm_type> {
+  foreach vti = vtilist in
+    defm : VPatBinaryWithPolicy<intrinsic, instruction # "_VI_" # vti.LMul.MX,
+                                vti.Vector, vti.Vector, XLenVT, vti.Mask,
+                                vti.Log2SEW, vti.RegClass,
+                                vti.RegClass, imm_type>;
+}
+
 multiclass VPatBinaryM_MM<string intrinsic, string instruction> {
   foreach mti = AllMasks in
     def : VPatBinaryNoMask<intrinsic, instruction # "_MM_" # mti.LMul.MX,
@@ -2914,6 +3037,12 @@
       VPatBinaryV_VX<intrinsic, instruction, vtilist>,
       VPatBinaryV_VI<intrinsic, instruction, vtilist, ImmType>;
 
+multiclass VPatBinaryV_VV_VX_VI_WithPolicy<string intrinsic, string instruction,
+                                list<VTypeInfo> vtilist, Operand ImmType = simm5>
+    : VPatBinaryV_VV_WithPolicy<intrinsic, instruction, vtilist>,
+      VPatBinaryV_VX_WithPolicy<intrinsic, instruction, vtilist>,
+      VPatBinaryV_VI_WithPolicy<intrinsic, instruction, vtilist, ImmType>;
+
 multiclass VPatBinaryV_VV_VX<string intrinsic, string instruction,
                              list<VTypeInfo> vtilist>
     : VPatBinaryV_VV<intrinsic, instruction, vtilist>,
@@ -3398,7 +3527,7 @@
 //===----------------------------------------------------------------------===//
 // 12.1. Vector Single-Width Integer Add and Subtract
 //===----------------------------------------------------------------------===//
-defm PseudoVADD        : VPseudoBinaryV_VV_VX_VI;
+defm PseudoVADD        : VPseudoBinaryV_VV_VX_VI_WithPolicy;
 defm PseudoVSUB        : VPseudoBinaryV_VV_VX;
 defm PseudoVRSUB       : VPseudoBinaryV_VX_VI;
 
@@ -3446,7 +3575,7 @@
                                                       (NegImm simm5_plus1:$rs2),
                                                       (vti.Mask V0),
                                                       GPR:$vl,
-                                                      vti.Log2SEW)>;
+                                                      vti.Log2SEW, 1)>;
 }
 
 //===----------------------------------------------------------------------===//
@@ -3958,7 +4087,7 @@
 //===----------------------------------------------------------------------===//
 // 12.1. Vector Single-Width Integer Add and Subtract
 //===----------------------------------------------------------------------===//
-defm : VPatBinaryV_VV_VX_VI<"int_riscv_vadd", "PseudoVADD", AllIntegerVectors>;
+defm : VPatBinaryV_VV_VX_VI_WithPolicy<"int_riscv_vadd", "PseudoVADD", AllIntegerVectors>;
 defm : VPatBinaryV_VV_VX<"int_riscv_vsub", "PseudoVSUB", AllIntegerVectors>;
 defm : VPatBinaryV_VX_VI<"int_riscv_vrsub", "PseudoVRSUB", AllIntegerVectors>;
 
Index: llvm/lib/Target/RISCV/RISCVInstrFormats.td
===================================================================
--- llvm/lib/Target/RISCV/RISCVInstrFormats.td
+++ llvm/lib/Target/RISCV/RISCVInstrFormats.td
@@ -178,6 +178,9 @@
 
   bit HasVLOp = 0;
   let TSFlags{15} = HasVLOp;
+
+  bit HasPolicy = false;
+  let TSFlags{16} = HasPolicy;
 }
 
 // Pseudo instructions
Index: llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
===================================================================
--- llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
+++ llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
@@ -365,7 +365,9 @@
 
   RISCVII::VLMUL VLMul = RISCVII::getLMul(TSFlags);
 
-  unsigned Log2SEW = MI.getOperand(NumOperands - 1).getImm();
+  unsigned Log2SEWIndex =
+      RISCVII::hasPolicy(TSFlags) ? NumOperands - 2 : NumOperands - 1;
+  unsigned Log2SEW = MI.getOperand(Log2SEWIndex).getImm();
   // A Log2SEW of 0 is an operation on mask registers only.
   bool MaskRegOp = Log2SEW == 0;
   unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
@@ -393,6 +395,12 @@
     }
   }
 
+  // If the instruction has a policy argument, use it to set TailAgnostic.
+  if (RISCVII::hasPolicy(TSFlags)) {
+    const MachineOperand &Op = MI.getOperand(NumOperands - 1);
+    TailAgnostic = Op.getImm();
+  }
+
   if (RISCVII::hasVLOp(TSFlags)) {
     const MachineOperand &VLOp = MI.getOperand(MI.getNumExplicitOperands() - 2);
     if (VLOp.isImm())
Index: llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
===================================================================
--- llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
+++ llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
@@ -76,6 +76,9 @@
   // explicit operand. Used by RVV Pseudos.
   HasVLOpShift = HasSEWOpShift + 1,
   HasVLOpMask = 1 << HasVLOpShift,
+
+  HasPolicyShift = HasVLOpShift + 1,
+  HasPolicyMask = 1 << HasPolicyShift,
 };
 
 // Match with the definitions in RISCVInstrFormatsV.td
@@ -132,6 +135,10 @@
   return TSFlags & HasVLOpMask;
 }
 
+static inline bool hasPolicy(uint64_t TSFlags) {
+  return TSFlags & HasPolicyMask;
+}
+
 // RISC-V Specific Machine Operand Flags
 enum {
   MO_None = 0,
Index: llvm/include/llvm/IR/IntrinsicsRISCV.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsRISCV.td
+++ llvm/include/llvm/IR/IntrinsicsRISCV.td
@@ -338,6 +338,14 @@
                    [IntrNoMem]>, RISCVVIntrinsic {
     let SplatOperand = 3;
   }
+  class RISCVBinaryAAXMaskTA
+       : Intrinsic<[llvm_anyvector_ty],
+                   [LLVMMatchType<0>, LLVMMatchType<0>, llvm_any_ty,
+                    LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, llvm_anyint_ty,
+                    LLVMMatchType<2>],
+                   [IntrNoMem]>, RISCVVIntrinsic {
+    let SplatOperand = 3;
+  }
   // For destination vector type is the same as first source vector. The
   // second source operand must match the destination type or be an XLen scalar.
   // Input: (vector_in, vector_in/scalar_in, vl)
@@ -817,6 +825,10 @@
     def "int_riscv_" # NAME : RISCVBinaryAAXNoMask;
     def "int_riscv_" # NAME # "_mask" : RISCVBinaryAAXMask;
   }
+  multiclass RISCVBinaryAAXWithPolicy {
+    def "int_riscv_" # NAME : RISCVBinaryAAXNoMask;
+    def "int_riscv_" # NAME # "_mask" : RISCVBinaryAAXMaskTA;
+  }
   // Like RISCVBinaryAAX, but the second operand is used a shift amount so it
   // must be a vector or an XLen scalar.
   multiclass RISCVBinaryAAShift {
@@ -960,7 +972,7 @@
   defm vamominu : RISCVAMO;
   defm vamomaxu : RISCVAMO;
 
-  defm vadd : RISCVBinaryAAX;
+  defm vadd : RISCVBinaryAAXWithPolicy;
   defm vsub : RISCVBinaryAAX;
   defm vrsub : RISCVBinaryAAX;
 
Index: clang/utils/TableGen/RISCVVEmitter.cpp
===================================================================
--- clang/utils/TableGen/RISCVVEmitter.cpp
+++ clang/utils/TableGen/RISCVVEmitter.cpp
@@ -156,6 +156,7 @@
   bool IsMask;
   bool HasMaskedOffOperand;
   bool HasVL;
+  bool HasPolicy;
   bool HasNoMaskedOverloaded;
   bool HasAutoDef; // There is automiatic definition in header
   std::string ManualCodegen;
@@ -169,8 +170,9 @@
 public:
   RVVIntrinsic(StringRef Name, StringRef Suffix, StringRef MangledName,
                StringRef IRName, bool HasSideEffects, bool IsMask,
-               bool HasMaskedOffOperand, bool HasVL, bool HasNoMaskedOverloaded,
-               bool HasAutoDef, StringRef ManualCodegen, const RVVTypes &Types,
+               bool HasMaskedOffOperand, bool HasVL, bool HasPolicy,
+               bool HasNoMaskedOverloaded, bool HasAutoDef,
+               StringRef ManualCodegen, const RVVTypes &Types,
                const std::vector<int64_t> &IntrinsicTypes,
                StringRef RequiredExtension);
   ~RVVIntrinsic() = default;
@@ -180,6 +182,7 @@
   bool hasSideEffects() const { return HasSideEffects; }
   bool hasMaskedOffOperand() const { return HasMaskedOffOperand; }
   bool hasVL() const { return HasVL; }
+  bool hasPolicy() const { return HasPolicy; }
   bool hasNoMaskedOverloaded() const { return HasNoMaskedOverloaded; }
   bool hasManualCodegen() const { return !ManualCodegen.empty(); }
   bool hasAutoDef() const { return HasAutoDef; }
@@ -195,6 +198,9 @@
   // init the RVVIntrinsic ID and IntrinsicTypes.
   void emitCodeGenSwitchBody(raw_ostream &o) const;
 
+  // Emit the macro definitions for the _m intrinsics in terms of the _mt intrinsics.
+  void emitIntrinsicMaskMacro(raw_ostream &o) const;
+
   // Emit the macros for mapping C/C++ intrinsic function to builtin functions.
   void emitIntrinsicMacro(raw_ostream &o) const;
 
@@ -227,6 +233,8 @@
 private:
   /// Create all intrinsics and add them to \p Out
   void createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> &Out);
+  /// Emit the header code gathered from RVVHeader records to \p OS.
+  void createRVVHeaders(raw_ostream &OS);
   /// Compute output and input types by applying different config (basic type
   /// and LMUL with type transformers). It also record result of type in legal
   /// or illegal set to avoid compute the  same config again. The result maybe
@@ -631,7 +639,7 @@
     ScalarType = ScalarTypeKind::SignedLong;
     break;
   default:
-    PrintFatalError("Illegal primitive type transformers!");
+    PrintFatalError("Illegal primitive type transformers: " + PType);
   }
   Transformer = Transformer.drop_back();
 
@@ -745,15 +753,15 @@
 RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix,
                            StringRef NewMangledName, StringRef IRName,
                            bool HasSideEffects, bool IsMask,
-                           bool HasMaskedOffOperand, bool HasVL,
+                           bool HasMaskedOffOperand, bool HasVL, bool HasPolicy,
                            bool HasNoMaskedOverloaded, bool HasAutoDef,
                            StringRef ManualCodegen, const RVVTypes &OutInTypes,
                            const std::vector<int64_t> &NewIntrinsicTypes,
                            StringRef RequiredExtension)
     : IRName(IRName), HasSideEffects(HasSideEffects), IsMask(IsMask),
       HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL),
-      HasNoMaskedOverloaded(HasNoMaskedOverloaded), HasAutoDef(HasAutoDef),
-      ManualCodegen(ManualCodegen.str()) {
+      HasPolicy(HasPolicy), HasNoMaskedOverloaded(HasNoMaskedOverloaded),
+      HasAutoDef(HasAutoDef), ManualCodegen(ManualCodegen.str()) {
 
   // Init Name and MangledName
   Name = NewName.str();
@@ -765,6 +773,8 @@
     Name += "_" + Suffix.str();
   if (IsMask) {
     Name += "_m";
+    if (HasPolicy)
+      Name += "t";
   }
   // Init RISC-V extensions
   for (const auto &T : OutInTypes) {
@@ -813,7 +823,10 @@
 
   if (isMask()) {
     if (hasVL()) {
-      OS << "  std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n";
+      if (hasPolicy())
+        OS << "  std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 2);\n";
+      else
+        OS << "  std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end() - 1);\n";
     } else {
       OS << "  std::rotate(Ops.begin(), Ops.begin() + 1, Ops.end());\n";
     }
@@ -853,6 +866,24 @@
   OS << ")\n";
 }
 
+void RVVIntrinsic::emitIntrinsicMaskMacro(raw_ostream &OS) const {
+  OS << "#define " << getName().drop_back() << "(";
+  if (!InputTypes.empty()) {
+    ListSeparator LS;
+    for (unsigned i = 0, e = InputTypes.size() - 1; i != e; ++i)
+      OS << LS << "op" << i;
+  }
+  OS << ") \\\n";
+  OS << "__builtin_rvv_" << getName() << "(";
+  ListSeparator LS;
+  if (!InputTypes.empty()) {
+    for (unsigned i = 0, e = InputTypes.size() - 1; i != e; ++i)
+      OS << LS << "(" << InputTypes[i]->getTypeStr() << ")(op" << i << ")";
+  }
+  OS << LS << "(size_t)VE_TAIL_AGNOSTIC";
+  OS << ")\n";
+}
+
 void RVVIntrinsic::emitMangledFuncDef(raw_ostream &OS) const {
   OS << "__attribute__((clang_builtin_alias(";
   OS << "__builtin_rvv_" << getName() << ")))\n";
@@ -898,6 +929,8 @@
   OS << "extern \"C\" {\n";
   OS << "#endif\n\n";
 
+  createRVVHeaders(OS);
+
   std::vector<std::unique_ptr<RVVIntrinsic>> Defs;
   createRVVIntrinsics(Defs);
 
@@ -965,6 +998,12 @@
     Inst.emitIntrinsicMacro(OS);
   });
 
+  // Use _mt to implement _m intrinsics.
+  emitArchMacroAndBody(Defs, OS, [](raw_ostream &OS, const RVVIntrinsic &Inst) {
+    if (Inst.isMask() && Inst.hasPolicy())
+      Inst.emitIntrinsicMaskMacro(OS);
+  });
+
   OS << "#define __riscv_v_intrinsic_overloading 1\n";
 
   // Print Overloaded APIs
@@ -1066,6 +1105,7 @@
     bool HasMask = R->getValueAsBit("HasMask");
     bool HasMaskedOffOperand = R->getValueAsBit("HasMaskedOffOperand");
     bool HasVL = R->getValueAsBit("HasVL");
+    bool HasPolicy = R->getValueAsBit("HasPolicy");
     bool HasNoMaskedOverloaded = R->getValueAsBit("HasNoMaskedOverloaded");
     bool HasSideEffects = R->getValueAsBit("HasSideEffects");
     std::vector<int64_t> Log2LMULList = R->getValueAsListOfInts("Log2LMUL");
@@ -1104,6 +1144,10 @@
       ProtoMaskSeq.push_back("z");
     }
 
+    if (HasPolicy) {
+      ProtoMaskSeq.push_back("z");
+    }
+
     // Create Intrinsics for each type and LMUL.
     for (char I : TypeRange) {
       for (int Log2LMUL : Log2LMULList) {
@@ -1116,7 +1160,7 @@
         // Create a non-mask intrinsic
         Out.push_back(std::make_unique<RVVIntrinsic>(
             Name, SuffixStr, MangledName, IRName, HasSideEffects,
-            /*IsMask=*/false, /*HasMaskedOffOperand=*/false, HasVL,
+            /*IsMask=*/false, /*HasMaskedOffOperand=*/false, HasVL, HasPolicy,
             HasNoMaskedOverloaded, HasAutoDef, ManualCodegen, Types.getValue(),
             IntrinsicTypes, RequiredExtension));
         if (HasMask) {
@@ -1125,7 +1169,7 @@
               computeTypes(I, Log2LMUL, ProtoMaskSeq);
           Out.push_back(std::make_unique<RVVIntrinsic>(
               Name, SuffixStr, MangledName, IRNameMask, HasSideEffects,
-              /*IsMask=*/true, HasMaskedOffOperand, HasVL,
+              /*IsMask=*/true, HasMaskedOffOperand, HasVL, HasPolicy,
               HasNoMaskedOverloaded, HasAutoDef, ManualCodegenMask,
               MaskTypes.getValue(), IntrinsicTypes, RequiredExtension));
         }
@@ -1134,6 +1178,15 @@
   }
 }
 
+void RVVEmitter::createRVVHeaders(raw_ostream &OS) {
+  std::vector<Record *> RVVHeaders =
+      Records.getAllDerivedDefinitions("RVVHeader");
+  for (auto *R : RVVHeaders) {
+    StringRef HeaderCodeStr = R->getValueAsString("HeaderCode");
+    OS << HeaderCodeStr.str();
+  }
+}
+
 Optional<RVVTypes>
 RVVEmitter::computeTypes(BasicType BT, int Log2LMUL,
                          ArrayRef<std::string> PrototypeSeq) {
Index: clang/test/CodeGen/RISCV/rvv-intrinsics/vadd-policy.c
===================================================================
--- /dev/null
+++ clang/test/CodeGen/RISCV/rvv-intrinsics/vadd-policy.c
@@ -0,0 +1,44 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py
+// REQUIRES: riscv-registered-target
+// RUN: %clang_cc1 -triple riscv64 -target-feature +experimental-v \
+// RUN:   -disable-O0-optnone -emit-llvm %s -o - | opt -S -mem2reg \
+// RUN:   | FileCheck --check-prefix=CHECK-RV64 %s
+
+#include <riscv_vector.h>
+
+
+// CHECK-RV64-LABEL: @test_vadd_vv_i8m1(
+// CHECK-RV64-NEXT:  entry:
+// CHECK-RV64-NEXT:    [[TMP0:%.*]] = call <vscale x 8 x i8> @llvm.riscv.vadd.nxv8i8.nxv8i8.i64(<vscale x 8 x i8> [[OP1:%.*]], <vscale x 8 x i8> [[OP2:%.*]], i64 [[VL:%.*]])
+// CHECK-RV64-NEXT:    ret <vscale x 8 x i8> [[TMP0]]
+//
+vint8m1_t test_vadd_vv_i8m1 (vint8m1_t op1, vint8m1_t op2, size_t vl) {
+  return vadd_vv_i8m1(op1, op2, vl);
+}
+
+// CHECK-RV64-LABEL: @test_vadd_vv_i8m1_m(
+// CHECK-RV64-NEXT:  entry:
+// CHECK-RV64-NEXT:    [[TMP0:%.*]] = call <vscale x 8 x i8> @llvm.riscv.vadd.mask.nxv8i8.nxv8i8.i64(<vscale x 8 x i8> [[MASKEDOFF:%.*]], <vscale x 8 x i8> [[OP1:%.*]], <vscale x 8 x i8> [[OP2:%.*]], <vscale x 8 x i1> [[MASK:%.*]], i64 [[VL:%.*]], i64 1)
+// CHECK-RV64-NEXT:    ret <vscale x 8 x i8> [[TMP0]]
+//
+vint8m1_t test_vadd_vv_i8m1_m (vbool8_t mask, vint8m1_t maskedoff, vint8m1_t op1, vint8m1_t op2, size_t vl) {
+  return vadd_vv_i8m1_m(mask, maskedoff, op1, op2, vl);
+}
+
+// CHECK-RV64-LABEL: @test_vadd_tu(
+// CHECK-RV64-NEXT:  entry:
+// CHECK-RV64-NEXT:    [[TMP0:%.*]] = call <vscale x 8 x i8> @llvm.riscv.vadd.mask.nxv8i8.nxv8i8.i64(<vscale x 8 x i8> [[MASKEDOFF:%.*]], <vscale x 8 x i8> [[OP1:%.*]], <vscale x 8 x i8> [[OP2:%.*]], <vscale x 8 x i1> [[MASK:%.*]], i64 [[VL:%.*]], i64 0)
+// CHECK-RV64-NEXT:    ret <vscale x 8 x i8> [[TMP0]]
+//
+vint8m1_t test_vadd_tu (vbool8_t mask, vint8m1_t maskedoff, vint8m1_t op1, vint8m1_t op2, size_t vl) {
+  return vadd_vv_i8m1_mt(mask, maskedoff, op1, op2, vl, VE_TAIL_UNDISTURBED);
+}
+
+// CHECK-RV64-LABEL: @test_vadd_ta(
+// CHECK-RV64-NEXT:  entry:
+// CHECK-RV64-NEXT:    [[TMP0:%.*]] = call <vscale x 8 x i8> @llvm.riscv.vadd.mask.nxv8i8.nxv8i8.i64(<vscale x 8 x i8> [[MASKEDOFF:%.*]], <vscale x 8 x i8> [[OP1:%.*]], <vscale x 8 x i8> [[OP2:%.*]], <vscale x 8 x i1> [[MASK:%.*]], i64 [[VL:%.*]], i64 1)
+// CHECK-RV64-NEXT:    ret <vscale x 8 x i8> [[TMP0]]
+//
+vint8m1_t test_vadd_ta (vbool8_t mask, vint8m1_t maskedoff, vint8m1_t op1, vint8m1_t op2, size_t vl) {
+  return vadd_vv_i8m1_mt(mask, maskedoff, op1, op2, vl, VE_TAIL_AGNOSTIC);
+}
Index: clang/include/clang/Basic/riscv_vector.td
===================================================================
--- clang/include/clang/Basic/riscv_vector.td
+++ clang/include/clang/Basic/riscv_vector.td
@@ -169,6 +169,13 @@
   // This builtin has a granted vector length parameter in the last position.
   bit HasVL = true;
 
+  // Normally, an intrinsic takes a policy argument when it is masked and
+  // takes no policy argument when it is unmasked. When HasPolicy is false,
+  // the intrinsic has no policy argument regardless of whether it is masked
+  // or unmasked. For example, when the result is a mask type or a scalar
+  // type, there is no need to specify the policy.
+  bit HasPolicy = true;
+
   // This builtin supports non-masked function overloading api.
   // All masked operations support overloading api.
   bit HasNoMaskedOverloaded = true;
@@ -1647,3 +1654,14 @@
     }
   }
 }
+
+class RVVHeader
+{
+  code HeaderCode;
+}
+
+let HeaderCode = [{
+#define VE_TAIL_UNDISTURBED 0
+#define VE_TAIL_AGNOSTIC 1
+}] in
+def policy : RVVHeader;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to