eopXD created this revision.
Herald added subscribers: jobnoorman, luke, VincentWu, vkmr, frasercrmck, 
luismarques, apazos, sameer.abuasal, s.egerton, Jim, benna, psnobl, jocewei, 
PkmX, the_o, brucehoult, MartinMosbeck, rogfer01, edward-jones, zzheng, jrtc27, 
shiva0217, kito-cheng, niosHD, sabuasal, simoncook, johnrusso, rbar, asb, 
hiraditya, arichardson.
Herald added a project: All.
eopXD requested review of this revision.
Herald added subscribers: llvm-commits, cfe-commits, pcwang-thead, MaskRay.
Herald added projects: clang, LLVM.

Depends on D152889 <https://reviews.llvm.org/D152889>.

A new TSFlags bit (IsRVVFixedPoint) is added to RISCVInstrFormats to distinguish
vector fixed-point instructions from vector floating-point instructions.
A new data member (HasFRMRoundModeOp) is added to RVVIntrinsic to identify the
vfadd variants that take a rounding mode control operand.

The sentinel value meaning "do not change the rounding mode" is changed from 4
to 99, because 4 is itself a valid frm value (frm ranges over [0, 4]).


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D152996

Files:
  clang/include/clang/Basic/riscv_vector.td
  clang/include/clang/Basic/riscv_vector_common.td
  clang/include/clang/Support/RISCVVIntrinsicUtils.h
  clang/lib/Sema/SemaRISCVVectorLookup.cpp
  clang/lib/Support/RISCVVIntrinsicUtils.cpp
  clang/utils/TableGen/RISCVVEmitter.cpp
  llvm/include/llvm/IR/IntrinsicsRISCV.td
  llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
  llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp
  llvm/lib/Target/RISCV/RISCVInstrFormats.td
  llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
  llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
  llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Index: llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
===================================================================
--- llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -611,7 +611,9 @@
                    op1_reg_class:$rs1,
                    op2_reg_class:$rs2,
                    (mask_type V0),
-                   (XLenVT 4), // vxrm value for RISCVInertReadWriteCSR
+                   // Value to indicate no rounding mode change in
+                   // RISCVInsertReadWriteCSR
+                   (XLenVT 99),
                    GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
 
 
@@ -706,7 +708,9 @@
                    vop_reg_class:$rs1,
                    xop_kind:$rs2,
                    (mask_type V0),
-                   (XLenVT 4), // vxrm value for RISCVInertReadWriteCSR
+                   // Value to indicate no rounding mode change in
+                   // RISCVInsertReadWriteCSR
+                   (XLenVT 99),
                    GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
 
 
@@ -861,6 +865,36 @@
                    scalar_reg_class:$rs2,
                    (mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
 
+class VPatBinaryVL_VF_RM<SDPatternOperator vop,
+                         string instruction_name,
+                         ValueType result_type,
+                         ValueType vop1_type,
+                         ValueType vop2_type,
+                         ValueType mask_type,
+                         int log2sew,
+                         LMULInfo vlmul,
+                         VReg result_reg_class,
+                         VReg vop_reg_class,
+                         RegisterClass scalar_reg_class,
+                         bit isSEWAware = 0>
+    : Pat<(result_type (vop (vop1_type vop_reg_class:$rs1),
+                       (vop2_type (SplatFPOp scalar_reg_class:$rs2)),
+                       (result_type result_reg_class:$merge),
+                       (mask_type V0),
+                       VLOpFrag)),
+      (!cast<Instruction>(
+                   !if(isSEWAware,
+                       instruction_name#"_"#vlmul.MX#"_E"#!shl(1, log2sew)#"_MASK",
+                       instruction_name#"_"#vlmul.MX#"_MASK"))
+                   result_reg_class:$merge,
+                   vop_reg_class:$rs1,
+                   scalar_reg_class:$rs2,
+                   (mask_type V0),
+                   // Value to indicate no rounding mode change in
+                   // RISCVInsertReadWriteCSR
+                   (XLenVT 99),
+                   GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
+
 multiclass VPatBinaryFPVL_VV_VF<SDPatternOperator vop, string instruction_name,
                                 bit isSEWAware = 0> {
   foreach vti = AllFloatVectors in {
@@ -877,6 +911,22 @@
   }
 }
 
+multiclass VPatBinaryFPVL_VV_VF_RM<SDPatternOperator vop, string instruction_name,
+                                   bit isSEWAware = 0> {
+  foreach vti = AllFloatVectors in {
+    let Predicates = GetVTypePredicates<vti>.Predicates in {
+      def : VPatBinaryVL_V_RM<vop, instruction_name, "VV",
+                              vti.Vector, vti.Vector, vti.Vector, vti.Mask,
+                              vti.Log2SEW, vti.LMul, vti.RegClass, vti.RegClass,
+                              vti.RegClass, isSEWAware>;
+      def : VPatBinaryVL_VF_RM<vop, instruction_name#"_V"#vti.ScalarSuffix,
+                               vti.Vector, vti.Vector, vti.Vector, vti.Mask,
+                               vti.Log2SEW, vti.LMul, vti.RegClass, vti.RegClass,
+                               vti.ScalarRegClass, isSEWAware>;
+    }
+  }
+}
+
 multiclass VPatBinaryFPVL_R_VF<SDPatternOperator vop, string instruction_name,
                                bit isSEWAware = 0> {
   foreach fvti = AllFloatVectors in {
@@ -1897,7 +1947,7 @@
 // 13. Vector Floating-Point Instructions
 
 // 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions
-defm : VPatBinaryFPVL_VV_VF<any_riscv_fadd_vl, "PseudoVFADD">;
+defm : VPatBinaryFPVL_VV_VF_RM<any_riscv_fadd_vl, "PseudoVFADD">;
 defm : VPatBinaryFPVL_VV_VF<any_riscv_fsub_vl, "PseudoVFSUB">;
 defm : VPatBinaryFPVL_R_VF<any_riscv_fsub_vl, "PseudoVFRSUB">;
 
Index: llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
===================================================================
--- llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -110,7 +110,9 @@
                          instruction_name#"_VV_"# vlmul.MX))
                      op_reg_class:$rs1,
                      op_reg_class:$rs2,
-                     (XLenVT 4), // vxrm value for RISCVInertReadWriteCSR
+                     // Value to indicate no rounding mode change in
+                     // RISCVInsertReadWriteCSR
+                     (XLenVT 99),
                      avl, log2sew)>;
 
 class VPatBinarySDNode_XI<SDPatternOperator vop,
@@ -157,7 +159,9 @@
                          instruction_name#_#suffix#_# vlmul.MX))
                      vop_reg_class:$rs1,
                      xop_kind:$rs2,
-                     (XLenVT 4), // vxrm value for RISCVInertReadWriteCSR
+                     // Value to indicate no rounding mode change in
+                     // RISCVInsertReadWriteCSR
+                     (XLenVT 99),
                      avl, log2sew)>;
 
 multiclass VPatBinarySDNode_VV_VX<SDPatternOperator vop, string instruction_name,
@@ -239,6 +243,30 @@
                      (xop_type xop_kind:$rs2),
                      avl, log2sew)>;
 
+class VPatBinarySDNode_VF_RM<SDPatternOperator vop,
+                             string instruction_name,
+                             ValueType result_type,
+                             ValueType vop_type,
+                             ValueType xop_type,
+                             int log2sew,
+                             LMULInfo vlmul,
+                             OutPatFrag avl,
+                             VReg vop_reg_class,
+                             DAGOperand xop_kind,
+                             bit isSEWAware = 0> :
+    Pat<(result_type (vop (vop_type vop_reg_class:$rs1),
+                          (vop_type (SplatFPOp xop_kind:$rs2)))),
+        (!cast<Instruction>(
+                     !if(isSEWAware,
+                         instruction_name#"_"#vlmul.MX#"_E"#!shl(1, log2sew),
+                         instruction_name#"_"#vlmul.MX))
+                     vop_reg_class:$rs1,
+                     (xop_type xop_kind:$rs2),
+                     // Value to indicate no rounding mode change in
+                     // RISCVInsertReadWriteCSR
+                     (XLenVT 99),
+                     avl, log2sew)>;
+
 multiclass VPatBinaryFPSDNode_VV_VF<SDPatternOperator vop, string instruction_name,
                                     bit isSEWAware = 0> {
   foreach vti = AllFloatVectors in {
@@ -254,6 +282,21 @@
   }
 }
 
+multiclass VPatBinaryFPSDNode_VV_VF_RM<SDPatternOperator vop, string instruction_name,
+                                       bit isSEWAware = 0> {
+  foreach vti = AllFloatVectors in {
+    let Predicates = GetVTypePredicates<vti>.Predicates in {
+      def : VPatBinarySDNode_VV_RM<vop, instruction_name,
+                                   vti.Vector, vti.Vector, vti.Log2SEW,
+                                   vti.LMul, vti.AVL, vti.RegClass, isSEWAware>;
+      def : VPatBinarySDNode_VF_RM<vop, instruction_name#"_V"#vti.ScalarSuffix,
+                                   vti.Vector, vti.Vector, vti.Scalar,
+                                   vti.Log2SEW, vti.LMul, vti.AVL, vti.RegClass,
+                                   vti.ScalarRegClass, isSEWAware>;
+    }
+  }
+}
+
 multiclass VPatBinaryFPSDNode_R_VF<SDPatternOperator vop, string instruction_name,
                                    bit isSEWAware = 0> {
   foreach fvti = AllFloatVectors in
@@ -993,7 +1036,7 @@
 // 13. Vector Floating-Point Instructions
 
 // 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions
-defm : VPatBinaryFPSDNode_VV_VF<any_fadd, "PseudoVFADD">;
+defm : VPatBinaryFPSDNode_VV_VF_RM<any_fadd, "PseudoVFADD">;
 defm : VPatBinaryFPSDNode_VV_VF<any_fsub, "PseudoVFSUB">;
 defm : VPatBinaryFPSDNode_R_VF<any_fsub, "PseudoVFRSUB">;
 
Index: llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
===================================================================
--- llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -1168,7 +1168,8 @@
                                       VReg Op1Class,
                                       DAGOperand Op2Class,
                                       string Constraint,
-                                      int DummyMask = 1> :
+                                      int DummyMask = 1,
+                                      int RVVFixedPoint = 1> :
         Pseudo<(outs RetClass:$rd),
                (ins Op1Class:$rs2, Op2Class:$rs1, ixlenimm:$rm, AVL:$vl, ixlenimm:$sew), []>,
         RISCVVPseudo {
@@ -1180,12 +1181,14 @@
   let HasSEWOp = 1;
   let HasDummyMask = DummyMask;
   let HasRoundModeOp = 1;
+  let IsRVVFixedPoint = RVVFixedPoint;
 }
 
 class VPseudoBinaryNoMaskTURoundingMode<VReg RetClass,
                                         VReg Op1Class,
                                         DAGOperand Op2Class,
-                                        string Constraint> :
+                                        string Constraint,
+                                        int RVVFixedPoint = 1> :
         Pseudo<(outs RetClass:$rd),
                (ins RetClass:$merge, Op1Class:$rs2, Op2Class:$rs1, ixlenimm:$rm,
                     AVL:$vl, ixlenimm:$sew), []>,
@@ -1199,12 +1202,14 @@
   let HasDummyMask = 1;
   let HasMergeOp = 1;
   let HasRoundModeOp = 1;
+  let IsRVVFixedPoint = RVVFixedPoint;
 }
 
 class VPseudoBinaryMaskPolicyRoundingMode<VReg RetClass,
                                           RegisterClass Op1Class,
                                           DAGOperand Op2Class,
-                                          string Constraint> :
+                                          string Constraint,
+                                          int RVVFixedPoint = 1> :
         Pseudo<(outs GetVRegNoV0<RetClass>.R:$rd),
                 (ins GetVRegNoV0<RetClass>.R:$merge,
                      Op1Class:$rs2, Op2Class:$rs1,
@@ -1221,6 +1226,7 @@
   let HasVecPolicyOp = 1;
   let UsesMaskPolicy = 1;
   let HasRoundModeOp = 1;
+  let IsRVVFixedPoint = RVVFixedPoint;
 }
 
 // Special version of VPseudoBinaryNoMask where we pretend the first source is
@@ -2036,16 +2042,18 @@
                                      VReg Op1Class,
                                      DAGOperand Op2Class,
                                      LMULInfo MInfo,
-                                     string Constraint = ""> {
+                                     string Constraint = "",
+                                     int IsRVVFixedPoint = 1> {
   let VLMul = MInfo.value in {
     def "_" # MInfo.MX : 
-      VPseudoBinaryNoMaskRoundingMode<RetClass, Op1Class, Op2Class, Constraint>;
+      VPseudoBinaryNoMaskRoundingMode<RetClass, Op1Class, Op2Class, Constraint,
+                                      /*DummyMask=*/ 1, IsRVVFixedPoint>;
     def "_" # MInfo.MX # "_TU" :
       VPseudoBinaryNoMaskTURoundingMode<RetClass, Op1Class, Op2Class,
-                                        Constraint>;
+                                        Constraint, IsRVVFixedPoint>;
     def "_" # MInfo.MX # "_MASK" :
       VPseudoBinaryMaskPolicyRoundingMode<RetClass, Op1Class, Op2Class,
-                                          Constraint>,
+                                          Constraint, IsRVVFixedPoint>,
                                           RISCVMaskedPseudo</*MaskOpIdx*/ 3>;
   }
 }
@@ -2109,6 +2117,11 @@
   defm _VV : VPseudoBinary<m.vrclass, m.vrclass, m.vrclass, m, Constraint, sew>;
 }
 
+multiclass VPseudoBinaryFV_VV_RM<LMULInfo m, string Constraint = ""> {
+  defm _VV : VPseudoBinaryRoundingMode<m.vrclass, m.vrclass, m.vrclass, m,
+                                       Constraint, /*IsRVVFixedPoint=*/ 0>;
+}
+
 multiclass VPseudoVGTR_VV_EEW<int eew, string Constraint = ""> {
   foreach m = MxList in {
     defvar mx = m.MX;
@@ -2157,6 +2170,12 @@
                                    f.fprclass, m, Constraint, sew>;
 }
 
+multiclass VPseudoBinaryV_VF_RM<LMULInfo m, FPR_Info f, string Constraint = ""> {
+  defm "_V" # f.FX : VPseudoBinaryRoundingMode<m.vrclass, m.vrclass,
+                                               f.fprclass, m, Constraint,
+                                               /*IsRVVFixedPoint=*/ 0>;
+}
+
 multiclass VPseudoVSLD1_VF<string Constraint = ""> {
   foreach f = FPList in {
     foreach m = f.MxList in {
@@ -2891,6 +2910,28 @@
   }
 }
 
+multiclass VPseudoVALU_VV_VF_RM {
+  foreach m = MxListF in {
+    defvar mx = m.MX;
+    defvar WriteVFALUV_MX = !cast<SchedWrite>("WriteVFALUV_" # mx);
+    defvar ReadVFALUV_MX = !cast<SchedRead>("ReadVFALUV_" # mx);
+
+    defm "" : VPseudoBinaryFV_VV_RM<m>,
+              Sched<[WriteVFALUV_MX, ReadVFALUV_MX, ReadVFALUV_MX, ReadVMask]>;
+  }
+
+  foreach f = FPList in {
+    foreach m = f.MxList in {
+      defvar mx = m.MX;
+      defvar WriteVFALUF_MX = !cast<SchedWrite>("WriteVFALUF_" # mx);
+      defvar ReadVFALUV_MX = !cast<SchedRead>("ReadVFALUV_" # mx);
+      defvar ReadVFALUF_MX = !cast<SchedRead>("ReadVFALUF_" # mx);
+      defm "" : VPseudoBinaryV_VF_RM<m, f>,
+                Sched<[WriteVFALUF_MX, ReadVFALUV_MX, ReadVFALUF_MX, ReadVMask]>;
+    }
+  }
+}
+
 multiclass VPseudoVALU_VF {
   foreach f = FPList in {
     foreach m = f.MxList in {
@@ -6008,7 +6049,7 @@
 // 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions
 //===----------------------------------------------------------------------===//
 let Uses = [FRM], mayRaiseFPException = true in {
-defm PseudoVFADD  : VPseudoVALU_VV_VF;
+defm PseudoVFADD  : VPseudoVALU_VV_VF_RM;
 defm PseudoVFSUB  : VPseudoVALU_VV_VF;
 defm PseudoVFRSUB : VPseudoVALU_VF;
 }
@@ -6681,7 +6722,8 @@
 //===----------------------------------------------------------------------===//
 // 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions
 //===----------------------------------------------------------------------===//
-defm : VPatBinaryV_VV_VX<"int_riscv_vfadd", "PseudoVFADD", AllFloatVectors>;
+defm : VPatBinaryV_VV_VXRoundingMode<"int_riscv_vfadd", "PseudoVFADD",
+                                     AllFloatVectors>;
 defm : VPatBinaryV_VV_VX<"int_riscv_vfsub", "PseudoVFSUB", AllFloatVectors>;
 defm : VPatBinaryV_VX<"int_riscv_vfrsub", "PseudoVFRSUB", AllFloatVectors>;
 
Index: llvm/lib/Target/RISCV/RISCVInstrFormats.td
===================================================================
--- llvm/lib/Target/RISCV/RISCVInstrFormats.td
+++ llvm/lib/Target/RISCV/RISCVInstrFormats.td
@@ -220,6 +220,14 @@
 
   bit HasRoundModeOp = 0;
   let TSFlags{20} =  HasRoundModeOp;
+
+  // This bit is only meaningful when HasRoundModeOp is set. HasRoundModeOp is
+  // set for both vector fixed-point and vector floating-point pseudos, so the
+  // RISCVInsertReadWriteCSR pass reads this bit to decide whether the rounding
+  // mode operand is written to vxrm (1, fixed-point) or to frm (0,
+  // floating-point).
+  bit IsRVVFixedPoint = 0;
+  let TSFlags{21} =  IsRVVFixedPoint;
 }
 
 // Pseudo instructions
Index: llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp
===================================================================
--- llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp
+++ llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp
@@ -13,6 +13,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "MCTargetDesc/RISCVBaseInfo.h"
+#include "MCTargetDesc/RISCVMCTargetDesc.h"
 #include "RISCV.h"
 #include "RISCVSubtarget.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
@@ -45,7 +47,7 @@
   }
 
 private:
-  bool emitWriteVXRM(MachineBasicBlock &MBB);
+  bool emitWriteRoundingMode(MachineBasicBlock &MBB);
   std::optional<unsigned> getRoundModeIdx(const MachineInstr &MI);
 };
 
@@ -74,22 +76,38 @@

-// This function inserts a write to vxrm when encountering an RVV fixed-point
-// instruction.
-bool RISCVInsertReadWriteCSR::emitWriteVXRM(MachineBasicBlock &MBB) {
+// This function inserts a write to vxrm (RVV fixed-point) or frm (RVV
+// floating-point) for each instruction with a static rounding mode operand.
+bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) {
   bool Changed = false;
   for (MachineInstr &MI : MBB) {
     if (auto RoundModeIdx = getRoundModeIdx(MI)) {
       Changed = true;
-
-      unsigned VXRMImm = MI.getOperand(*RoundModeIdx).getImm();
-
-      // The value '4' is a hint to this pass to not alter the vxrm value.
-      if (VXRMImm == 4)
-        continue;
-
-      BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteVXRMImm))
-          .addImm(VXRMImm);
-      MI.addOperand(MachineOperand::CreateReg(RISCV::VXRM, /*IsDef*/ false,
-                                              /*IsImp*/ true));
+      if (RISCVII::isRVVFixedPoint(MI.getDesc().TSFlags)) {
+        unsigned VXRMImm = MI.getOperand(*RoundModeIdx).getImm();
+
+        // The value '99' is a hint to this pass to not alter the vxrm value.
+        if (VXRMImm == 99)
+          continue;
+
+        BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteVXRMImm))
+            .addImm(VXRMImm);
+        MI.addOperand(MachineOperand::CreateReg(RISCV::VXRM, /*IsDef*/ false,
+                                                /*IsImp*/ true));
+      } else { // FRM
+        unsigned FRMImm = MI.getOperand(*RoundModeIdx).getImm();
+
+        // The value '99' is a hint to this pass to not alter the frm value.
+        if (FRMImm == 99)
+          continue;
+
+        // NOTE(review): frm is overwritten here and never restored, so this
+        // static rounding mode stays in effect for later instructions that
+        // read frm (dynamic rounding). Consider saving frm before this write
+        // and restoring it after MI.
+        BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRMImm))
+            .addImm(FRMImm);
+        MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
+                                                /*IsImp*/ true));
+      }
     }
   }
   return Changed;
@@ -106,7 +124,7 @@
   bool Changed = false;
 
   for (MachineBasicBlock &MBB : MF)
-    Changed |= emitWriteVXRM(MBB);
+    Changed |= emitWriteRoundingMode(MBB);
 
   return Changed;
 }
Index: llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
===================================================================
--- llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
+++ llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
@@ -115,6 +115,9 @@
 
   HasRoundModeOpShift = IsSignExtendingOpWShift + 1,
   HasRoundModeOpMask = 1 << HasRoundModeOpShift,
+
+  IsRVVFixedPointShift = HasRoundModeOpShift + 1,
+  IsRVVFixedPointMask = 1 << IsRVVFixedPointShift,
 };
 
 enum VLMUL : uint8_t {
@@ -181,6 +184,11 @@
   return TSFlags & HasRoundModeOpMask;
 }
 
+/// \returns true if this is an RVV fixed-point (vxrm) instruction.
+static inline bool isRVVFixedPoint(uint64_t TSFlags) {
+  return TSFlags & IsRVVFixedPointMask;
+}
+
 static inline unsigned getMergeOpNum(const MCInstrDesc &Desc) {
   assert(hasMergeOp(Desc.TSFlags));
   assert(!Desc.isVariadic());
Index: llvm/include/llvm/IR/IntrinsicsRISCV.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsRISCV.td
+++ llvm/include/llvm/IR/IntrinsicsRISCV.td
@@ -420,6 +420,27 @@
     let ScalarOperand = 2;
     let VLOperand = 4;
   }
+  // Destination vector type matches the first source; frm must be an immarg.
+  // Input: (passthru, vector_in, vector_in/scalar_in, frm, vl)
+  class RISCVBinaryAAXUnMaskedRoundingMode
+        : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+                    [LLVMMatchType<0>, LLVMMatchType<0>, llvm_any_ty,
+                     llvm_anyint_ty, LLVMMatchType<2>],
+                    [ImmArg<ArgIndex<3>>, IntrNoMem]>, RISCVVIntrinsic {
+    let ScalarOperand = 2;
+    let VLOperand = 4;
+  }
+  // Masked form of the above; frm and policy must both be immargs.
+  // Input: (maskedoff, vector_in, vector_in/scalar_in, mask, frm, vl, policy)
+  class RISCVBinaryAAXMaskedRoundingMode
+       : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+                   [LLVMMatchType<0>, LLVMMatchType<0>, llvm_any_ty,
+                    LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, llvm_anyint_ty,
+                    LLVMMatchType<2>, LLVMMatchType<2>],
+                   [ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<6>>, IntrNoMem]>, RISCVVIntrinsic {
+    let ScalarOperand = 2;
+    let VLOperand = 5;
+  }
   // For destination vector type is the same as first source vector. The
   // second source operand must match the destination type or be an XLen scalar.
   // Input: (passthru, vector_in, vector_in/scalar_in, vl)
@@ -1084,6 +1105,10 @@
     def "int_riscv_" # NAME : RISCVBinaryAAXUnMasked;
     def "int_riscv_" # NAME # "_mask" : RISCVBinaryAAXMasked;
   }
+  multiclass RISCVBinaryAAXRoundingMode {
+    def "int_riscv_" # NAME : RISCVBinaryAAXUnMaskedRoundingMode;
+    def "int_riscv_" # NAME # "_mask" : RISCVBinaryAAXMaskedRoundingMode;
+  }
   // Like RISCVBinaryAAX, but the second operand is used a shift amount so it
   // must be a vector or an XLen scalar.
   multiclass RISCVBinaryAAShift {
@@ -1292,7 +1317,7 @@
   defm vwmaccus : RISCVTernaryWide;
   defm vwmaccsu : RISCVTernaryWide;
 
-  defm vfadd : RISCVBinaryAAX;
+  defm vfadd : RISCVBinaryAAXRoundingMode;
   defm vfsub : RISCVBinaryAAX;
   defm vfrsub : RISCVBinaryAAX;
 
Index: clang/utils/TableGen/RISCVVEmitter.cpp
===================================================================
--- clang/utils/TableGen/RISCVVEmitter.cpp
+++ clang/utils/TableGen/RISCVVEmitter.cpp
@@ -65,6 +65,7 @@
   bool HasMaskedOffOperand :1;
   bool HasTailPolicy : 1;
   bool HasMaskPolicy : 1;
+  bool HasFRMRoundModeOp : 1; // Builtin has an frm rounding mode operand.
   bool IsTuple : 1;
   uint8_t UnMaskedPolicyScheme : 2;
   uint8_t MaskedPolicyScheme : 2;
@@ -512,6 +513,7 @@
     StringRef MaskedIRName = R->getValueAsString("MaskedIRName");
     unsigned NF = R->getValueAsInt("NF");
     bool IsTuple = R->getValueAsBit("IsTuple");
+    bool HasFRMRoundModeOp = R->getValueAsBit("HasFRMRoundModeOp");
 
     const Policy DefaultPolicy;
     SmallVector<Policy> SupportedUnMaskedPolicies =
@@ -559,7 +561,7 @@
             /*IsMasked=*/false, /*HasMaskedOffOperand=*/false, HasVL,
             UnMaskedPolicyScheme, SupportOverloading, HasBuiltinAlias,
             ManualCodegen, *Types, IntrinsicTypes, RequiredFeatures, NF,
-            DefaultPolicy));
+            DefaultPolicy, HasFRMRoundModeOp));
         if (UnMaskedPolicyScheme != PolicyScheme::SchemeNone)
           for (auto P : SupportedUnMaskedPolicies) {
             SmallVector<PrototypeDescriptor> PolicyPrototype =
@@ -574,7 +576,7 @@
                 /*IsMask=*/false, /*HasMaskedOffOperand=*/false, HasVL,
                 UnMaskedPolicyScheme, SupportOverloading, HasBuiltinAlias,
                 ManualCodegen, *PolicyTypes, IntrinsicTypes, RequiredFeatures,
-                NF, P));
+                NF, P, HasFRMRoundModeOp));
           }
         if (!HasMasked)
           continue;
@@ -585,7 +587,8 @@
             Name, SuffixStr, OverloadedName, OverloadedSuffixStr, MaskedIRName,
             /*IsMasked=*/true, HasMaskedOffOperand, HasVL, MaskedPolicyScheme,
             SupportOverloading, HasBuiltinAlias, ManualCodegen, *MaskTypes,
-            IntrinsicTypes, RequiredFeatures, NF, DefaultPolicy));
+            IntrinsicTypes, RequiredFeatures, NF, DefaultPolicy,
+            HasFRMRoundModeOp));
         if (MaskedPolicyScheme == PolicyScheme::SchemeNone)
           continue;
         for (auto P : SupportedMaskedPolicies) {
@@ -600,7 +603,7 @@
               MaskedIRName, /*IsMasked=*/true, HasMaskedOffOperand, HasVL,
               MaskedPolicyScheme, SupportOverloading, HasBuiltinAlias,
               ManualCodegen, *PolicyTypes, IntrinsicTypes, RequiredFeatures, NF,
-              P));
+              P, HasFRMRoundModeOp));
         }
       } // End for Log2LMULList
     }   // End for TypeRange
@@ -653,6 +656,7 @@
     SR.Suffix = parsePrototypes(SuffixProto);
     SR.OverloadedSuffix = parsePrototypes(OverloadedSuffixProto);
     SR.IsTuple = IsTuple;
+    SR.HasFRMRoundModeOp = HasFRMRoundModeOp;
 
     SemaRecords->push_back(SR);
   }
@@ -695,6 +699,7 @@
     R.UnMaskedPolicyScheme = SR.UnMaskedPolicyScheme;
     R.MaskedPolicyScheme = SR.MaskedPolicyScheme;
     R.IsTuple = SR.IsTuple;
+    R.HasFRMRoundModeOp = SR.HasFRMRoundModeOp;
 
     assert(R.PrototypeIndex !=
            static_cast<uint16_t>(SemaSignatureTable::INVALID_INDEX));
Index: clang/lib/Support/RISCVVIntrinsicUtils.cpp
===================================================================
--- clang/lib/Support/RISCVVIntrinsicUtils.cpp
+++ clang/lib/Support/RISCVVIntrinsicUtils.cpp
@@ -870,20 +870,19 @@
 //===----------------------------------------------------------------------===//
 // RVVIntrinsic implementation
 //===----------------------------------------------------------------------===//
-RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix,
-                           StringRef NewOverloadedName,
-                           StringRef OverloadedSuffix, StringRef IRName,
-                           bool IsMasked, bool HasMaskedOffOperand, bool HasVL,
-                           PolicyScheme Scheme, bool SupportOverloading,
-                           bool HasBuiltinAlias, StringRef ManualCodegen,
-                           const RVVTypes &OutInTypes,
-                           const std::vector<int64_t> &NewIntrinsicTypes,
-                           const std::vector<StringRef> &RequiredFeatures,
-                           unsigned NF, Policy NewPolicyAttrs)
+RVVIntrinsic::RVVIntrinsic(
+    StringRef NewName, StringRef Suffix, StringRef NewOverloadedName,
+    StringRef OverloadedSuffix, StringRef IRName, bool IsMasked,
+    bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme,
+    bool SupportOverloading, bool HasBuiltinAlias, StringRef ManualCodegen,
+    const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes,
+    const std::vector<StringRef> &RequiredFeatures, unsigned NF,
+    Policy NewPolicyAttrs, bool HasFRMRoundModeOp)
     : IRName(IRName), IsMasked(IsMasked),
       HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme),
       SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias),
-      ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) {
+      ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs),
+      HasFRMRoundModeOp(HasFRMRoundModeOp) {
 
   // Init BuiltinName, Name and OverloadedName
   BuiltinName = NewName.str();
@@ -898,7 +897,7 @@
     OverloadedName += "_" + OverloadedSuffix.str();
 
   updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName,
-                       PolicyAttrs);
+                       PolicyAttrs, HasFRMRoundModeOp);
 
   // Init OutputType and InputTypes
   OutputType = OutInTypes[0];
@@ -1023,13 +1022,11 @@
                    "and mask policy");
 }
 
-void RVVIntrinsic::updateNamesAndPolicy(bool IsMasked, bool HasPolicy,
-                                        std::string &Name,
-                                        std::string &BuiltinName,
-                                        std::string &OverloadedName,
-                                        Policy &PolicyAttrs) {
+void RVVIntrinsic::updateNamesAndPolicy(
+    bool IsMasked, bool HasPolicy, std::string &Name, std::string &BuiltinName,
+    std::string &OverloadedName, Policy &PolicyAttrs, bool HasFRMRoundModeOp) {
 
-  auto appendPolicySuffix = [&](const std::string &suffix) {
+  auto appendSuffix = [&](const std::string &suffix) {
     Name += suffix;
     BuiltinName += suffix;
     OverloadedName += suffix;
@@ -1042,11 +1039,11 @@
 
   if (IsMasked) {
     if (PolicyAttrs.isTUMUPolicy())
-      appendPolicySuffix("_tumu");
+      appendSuffix("_tumu");
     else if (PolicyAttrs.isTUMAPolicy())
-      appendPolicySuffix("_tum");
+      appendSuffix("_tum");
     else if (PolicyAttrs.isTAMUPolicy())
-      appendPolicySuffix("_mu");
+      appendSuffix("_mu");
     else if (PolicyAttrs.isTAMAPolicy()) {
       Name += "_m";
       if (HasPolicy)
@@ -1057,13 +1054,16 @@
       llvm_unreachable("Unhandled policy condition");
   } else {
     if (PolicyAttrs.isTUPolicy())
-      appendPolicySuffix("_tu");
+      appendSuffix("_tu");
     else if (PolicyAttrs.isTAPolicy()) {
       if (HasPolicy)
         BuiltinName += "_ta";
     } else
       llvm_unreachable("Unhandled policy condition");
   }
+  // Builtins with a rounding mode operand get "_rm" after any policy suffix.
+  if (HasFRMRoundModeOp)
+    appendSuffix("_rm");
 }
 
 SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) {
@@ -1110,6 +1110,7 @@
   OS << (int)Record.HasMaskedOffOperand << ",";
   OS << (int)Record.HasTailPolicy << ",";
   OS << (int)Record.HasMaskPolicy << ",";
+  OS << (int)Record.HasFRMRoundModeOp << ",";
   OS << (int)Record.IsTuple << ",";
   OS << (int)Record.UnMaskedPolicyScheme << ",";
   OS << (int)Record.MaskedPolicyScheme << ",";
Index: clang/lib/Sema/SemaRISCVVectorLookup.cpp
===================================================================
--- clang/lib/Sema/SemaRISCVVectorLookup.cpp
+++ clang/lib/Sema/SemaRISCVVectorLookup.cpp
@@ -349,7 +349,8 @@
   std::string BuiltinName = "__builtin_rvv_" + std::string(Record.Name);
 
   RVVIntrinsic::updateNamesAndPolicy(IsMasked, HasPolicy, Name, BuiltinName,
-                                     OverloadedName, PolicyAttrs);
+                                     OverloadedName, PolicyAttrs,
+                                     Record.HasFRMRoundModeOp);
 
   // Put into IntrinsicList.
   size_t Index = IntrinsicList.size();
Index: clang/include/clang/Support/RISCVVIntrinsicUtils.h
===================================================================
--- clang/include/clang/Support/RISCVVIntrinsicUtils.h
+++ clang/include/clang/Support/RISCVVIntrinsicUtils.h
@@ -381,6 +381,7 @@
   std::vector<int64_t> IntrinsicTypes;
   unsigned NF = 1;
   Policy PolicyAttrs;
+  bool HasFRMRoundModeOp = false;
 
 public:
   RVVIntrinsic(llvm::StringRef Name, llvm::StringRef Suffix,
@@ -391,7 +392,7 @@
                const RVVTypes &Types,
                const std::vector<int64_t> &IntrinsicTypes,
                const std::vector<llvm::StringRef> &RequiredFeatures,
-               unsigned NF, Policy PolicyAttrs);
+               unsigned NF, Policy PolicyAttrs, bool HasFRMRoundModeOp);
   ~RVVIntrinsic() = default;
 
   RVVTypePtr getOutputType() const { return OutputType; }
@@ -461,7 +462,7 @@
   static void updateNamesAndPolicy(bool IsMasked, bool HasPolicy,
                                    std::string &Name, std::string &BuiltinName,
                                    std::string &OverloadedName,
-                                   Policy &PolicyAttrs);
+                                   Policy &PolicyAttrs, bool HasFRMRoundModeOp);
 };
 
 // RVVRequire should be sync'ed with target features, but only
@@ -520,6 +521,7 @@
   bool HasMaskedOffOperand : 1;
   bool HasTailPolicy : 1;
   bool HasMaskPolicy : 1;
+  bool HasFRMRoundModeOp : 1;
   bool IsTuple : 1;
   uint8_t UnMaskedPolicyScheme : 2;
   uint8_t MaskedPolicyScheme : 2;
Index: clang/include/clang/Basic/riscv_vector_common.td
===================================================================
--- clang/include/clang/Basic/riscv_vector_common.td
+++ clang/include/clang/Basic/riscv_vector_common.td
@@ -234,6 +234,10 @@
 
   // Set to true if the builtin is associated with tuple types.
   bit IsTuple = false;
+
+  // Set to true if the builtin has a parameter that models floating-point
+  // rounding mode control.
+  bit HasFRMRoundModeOp = false;
 }
 
 // This is the code emitted in the header.
Index: clang/include/clang/Basic/riscv_vector.td
===================================================================
--- clang/include/clang/Basic/riscv_vector.td
+++ clang/include/clang/Basic/riscv_vector.td
@@ -226,6 +226,11 @@
                           [["vv", "v", "vvv"],
                            ["vf", "v", "vve"]]>;
 
+multiclass RVVFloatingBinBuiltinSetRoundingMode
+    : RVVOutOp1BuiltinSet<NAME, "xfd",
+                          [["vv", "v", "vvvu"],
+                           ["vf", "v", "vveu"]]>;
+
 multiclass RVVFloatingBinVFBuiltinSet
     : RVVOutOp1BuiltinSet<NAME, "xfd",
                           [["vf", "v", "vve"]]>;
@@ -2206,10 +2211,71 @@
   defm vnclipu : RVVUnsignedNShiftBuiltinSetRoundingMode;
   defm vnclip : RVVSignedNShiftBuiltinSetRoundingMode;
 }
+}
 
 // 14. Vector Floating-Point Instructions
+let HeaderCode =
+[{
+enum __RISCV_FRM {
+  __RISCV_FRM_RNE = 0,
+  __RISCV_FRM_RTZ = 1,
+  __RISCV_FRM_RDN = 2,
+  __RISCV_FRM_RUP = 3,
+  __RISCV_FRM_RMM = 4,
+};
+}] in def frm_enum : RVVHeader;
+
+let UnMaskedPolicyScheme = HasPassthruOperand in {
 // 14.2. Vector Single-Width Floating-Point Add/Subtract Instructions
-defm vfadd  : RVVFloatingBinBuiltinSet;
+let ManualCodegen = [{
+  {
+    // LLVM intrinsic
+    // Unmasked: (passthru, vector_in, vector_in/scalar_in, frm, vl)
+    // Masked:   (passthru, vector_in, vector_in/scalar_in, mask, frm, vl, policy)
+
+    SmallVector<llvm::Value*, 7> Operands;
+    bool HasMaskedOff = !(
+        (IsMasked && (PolicyAttrs & RVV_VTA) && (PolicyAttrs & RVV_VMA)) ||
+        (!IsMasked && PolicyAttrs & RVV_VTA));
+    bool HasRoundModeOp = IsMasked ?
+      (HasMaskedOff ? Ops.size() == 6 : Ops.size() == 5) :
+      (HasMaskedOff ? Ops.size() == 5 : Ops.size() == 4);
+
+    unsigned Offset = IsMasked ?
+        (HasMaskedOff ? 2 : 1) : (HasMaskedOff ? 1 : 0);
+
+    if (!HasMaskedOff)
+      Operands.push_back(llvm::PoisonValue::get(ResultType));
+    else
+      Operands.push_back(Ops[IsMasked ? 1 : 0]);
+
+    Operands.push_back(Ops[Offset]); // op0
+    Operands.push_back(Ops[Offset + 1]); // op1
+
+    if (IsMasked)
+      Operands.push_back(Ops[0]); // mask
+
+    if (HasRoundModeOp) {
+      Operands.push_back(Ops[Offset + 2]); // frm
+      Operands.push_back(Ops[Offset + 3]); // vl
+    } else {
+      Operands.push_back(ConstantInt::get(Ops[Offset + 2]->getType(), 99)); // frm
+      Operands.push_back(Ops[Offset + 2]); // vl
+    }
+
+    if (IsMasked)
+      Operands.push_back(ConstantInt::get(Ops.back()->getType(), PolicyAttrs));
+
+    IntrinsicTypes = {ResultType, Ops[Offset + 1]->getType(), Ops.back()->getType()};
+    llvm::Function *F = CGM.getIntrinsic(ID, IntrinsicTypes);
+    return Builder.CreateCall(F, Operands, "");
+  }
+}] in {
+  let HasFRMRoundModeOp = true in {
+    defm vfadd  : RVVFloatingBinBuiltinSetRoundingMode;
+  }
+  defm vfadd  : RVVFloatingBinBuiltinSet;
+}
 defm vfsub  : RVVFloatingBinBuiltinSet;
 defm vfrsub : RVVFloatingBinVFBuiltinSet;
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to