eopXD created this revision. Herald added subscribers: jobnoorman, luke, VincentWu, vkmr, frasercrmck, luismarques, apazos, sameer.abuasal, s.egerton, Jim, benna, psnobl, jocewei, PkmX, the_o, brucehoult, MartinMosbeck, rogfer01, edward-jones, zzheng, jrtc27, shiva0217, kito-cheng, niosHD, sabuasal, simoncook, johnrusso, rbar, asb, hiraditya, arichardson. Herald added a project: All. eopXD requested review of this revision. Herald added subscribers: llvm-commits, cfe-commits, pcwang-thead, MaskRay. Herald added projects: clang, LLVM.
Depends on D152889 <https://reviews.llvm.org/D152889>. An additional data member, IsRVVFixedPoint, is added to RISCVInstrFormats to distinguish vector fixed-point instructions from vector floating-point instructions. An additional data member, HasFRMRoundModeOp, is added to RVVIntrinsic to distinguish the vfadd variants that take a rounding-mode control operand. The value indicating "no rounding-mode change" is changed from 4 to 99, because frm has a valid range of [0, 4]: 4 encodes a real rounding mode (RMM), so it can no longer serve as the sentinel. Repository: rG LLVM Github Monorepo https://reviews.llvm.org/D152996 Files: clang/include/clang/Basic/riscv_vector.td clang/include/clang/Basic/riscv_vector_common.td clang/include/clang/Support/RISCVVIntrinsicUtils.h clang/lib/Sema/SemaRISCVVectorLookup.cpp clang/lib/Support/RISCVVIntrinsicUtils.cpp clang/utils/TableGen/RISCVVEmitter.cpp llvm/include/llvm/IR/IntrinsicsRISCV.td llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp llvm/lib/Target/RISCV/RISCVInstrFormats.td llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Index: llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td =================================================================== --- llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -611,7 +611,9 @@ op1_reg_class:$rs1, op2_reg_class:$rs2, (mask_type V0), - (XLenVT 4), // vxrm value for RISCVInertReadWriteCSR + // Value to indicate no rounding mode change in + // RISCVInertReadWriteCSR + (XLenVT 99), GPR:$vl, log2sew, TAIL_AGNOSTIC)>; @@ -706,7 +708,9 @@ vop_reg_class:$rs1, xop_kind:$rs2, (mask_type V0), - (XLenVT 4), // vxrm value for RISCVInertReadWriteCSR + // Value to indicate no rounding mode change in + // RISCVInertReadWriteCSR + (XLenVT 99), GPR:$vl, log2sew, TAIL_AGNOSTIC)>; @@ -861,6 +865,36 @@ scalar_reg_class:$rs2, (mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>; +class VPatBinaryVL_VF_RM<SDPatternOperator vop, + string instruction_name, + ValueType result_type, + ValueType vop1_type, + ValueType vop2_type, + ValueType mask_type, + int log2sew, + LMULInfo vlmul, + VReg result_reg_class, + VReg vop_reg_class, + RegisterClass scalar_reg_class, + bit isSEWAware = 0> + : Pat<(result_type (vop (vop1_type vop_reg_class:$rs1), + (vop2_type (SplatFPOp scalar_reg_class:$rs2)), + (result_type result_reg_class:$merge), + (mask_type V0), + VLOpFrag)), + (!cast<Instruction>( + !if(isSEWAware, + instruction_name#"_"#vlmul.MX#"_E"#!shl(1, log2sew)#"_MASK", + instruction_name#"_"#vlmul.MX#"_MASK")) + result_reg_class:$merge, + vop_reg_class:$rs1, + scalar_reg_class:$rs2, + (mask_type V0), + // Value to indicate no rounding mode change in + // RISCVInertReadWriteCSR + (XLenVT 99), + GPR:$vl, log2sew, TAIL_AGNOSTIC)>; + multiclass VPatBinaryFPVL_VV_VF<SDPatternOperator vop, string instruction_name, bit isSEWAware = 0> { foreach vti = AllFloatVectors in { @@ -877,6 +911,22 @@ } } +multiclass VPatBinaryFPVL_VV_VF_RM<SDPatternOperator vop, string instruction_name, + bit isSEWAware = 0> { + foreach vti = 
AllFloatVectors in { + let Predicates = GetVTypePredicates<vti>.Predicates in { + def : VPatBinaryVL_V_RM<vop, instruction_name, "VV", + vti.Vector, vti.Vector, vti.Vector, vti.Mask, + vti.Log2SEW, vti.LMul, vti.RegClass, vti.RegClass, + vti.RegClass, isSEWAware>; + def : VPatBinaryVL_VF_RM<vop, instruction_name#"_V"#vti.ScalarSuffix, + vti.Vector, vti.Vector, vti.Vector, vti.Mask, + vti.Log2SEW, vti.LMul, vti.RegClass, vti.RegClass, + vti.ScalarRegClass, isSEWAware>; + } + } +} + multiclass VPatBinaryFPVL_R_VF<SDPatternOperator vop, string instruction_name, bit isSEWAware = 0> { foreach fvti = AllFloatVectors in { @@ -1897,7 +1947,7 @@ // 13. Vector Floating-Point Instructions // 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions -defm : VPatBinaryFPVL_VV_VF<any_riscv_fadd_vl, "PseudoVFADD">; +defm : VPatBinaryFPVL_VV_VF_RM<any_riscv_fadd_vl, "PseudoVFADD">; defm : VPatBinaryFPVL_VV_VF<any_riscv_fsub_vl, "PseudoVFSUB">; defm : VPatBinaryFPVL_R_VF<any_riscv_fsub_vl, "PseudoVFRSUB">; Index: llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td =================================================================== --- llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td +++ llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td @@ -110,7 +110,9 @@ instruction_name#"_VV_"# vlmul.MX)) op_reg_class:$rs1, op_reg_class:$rs2, - (XLenVT 4), // vxrm value for RISCVInertReadWriteCSR + // Value to indicate no rounding mode change in + // RISCVInertReadWriteCSR + (XLenVT 99), avl, log2sew)>; class VPatBinarySDNode_XI<SDPatternOperator vop, @@ -157,7 +159,9 @@ instruction_name#_#suffix#_# vlmul.MX)) vop_reg_class:$rs1, xop_kind:$rs2, - (XLenVT 4), // vxrm value for RISCVInertReadWriteCSR + // Value to indicate no rounding mode change in + // RISCVInertReadWriteCSR + (XLenVT 99), avl, log2sew)>; multiclass VPatBinarySDNode_VV_VX<SDPatternOperator vop, string instruction_name, @@ -239,6 +243,30 @@ (xop_type xop_kind:$rs2), avl, log2sew)>; +class 
VPatBinarySDNode_VF_RM<SDPatternOperator vop, + string instruction_name, + ValueType result_type, + ValueType vop_type, + ValueType xop_type, + int log2sew, + LMULInfo vlmul, + OutPatFrag avl, + VReg vop_reg_class, + DAGOperand xop_kind, + bit isSEWAware = 0> : + Pat<(result_type (vop (vop_type vop_reg_class:$rs1), + (vop_type (SplatFPOp xop_kind:$rs2)))), + (!cast<Instruction>( + !if(isSEWAware, + instruction_name#"_"#vlmul.MX#"_E"#!shl(1, log2sew), + instruction_name#"_"#vlmul.MX)) + vop_reg_class:$rs1, + (xop_type xop_kind:$rs2), + // Value to indicate no rounding mode change in + // RISCVInertReadWriteCSR + (XLenVT 99), + avl, log2sew)>; + multiclass VPatBinaryFPSDNode_VV_VF<SDPatternOperator vop, string instruction_name, bit isSEWAware = 0> { foreach vti = AllFloatVectors in { @@ -254,6 +282,21 @@ } } +multiclass VPatBinaryFPSDNode_VV_VF_RM<SDPatternOperator vop, string instruction_name, + bit isSEWAware = 0> { + foreach vti = AllFloatVectors in { + let Predicates = GetVTypePredicates<vti>.Predicates in { + def : VPatBinarySDNode_VV_RM<vop, instruction_name, + vti.Vector, vti.Vector, vti.Log2SEW, + vti.LMul, vti.AVL, vti.RegClass, isSEWAware>; + def : VPatBinarySDNode_VF_RM<vop, instruction_name#"_V"#vti.ScalarSuffix, + vti.Vector, vti.Vector, vti.Scalar, + vti.Log2SEW, vti.LMul, vti.AVL, vti.RegClass, + vti.ScalarRegClass, isSEWAware>; + } + } +} + multiclass VPatBinaryFPSDNode_R_VF<SDPatternOperator vop, string instruction_name, bit isSEWAware = 0> { foreach fvti = AllFloatVectors in @@ -993,7 +1036,7 @@ // 13. Vector Floating-Point Instructions // 13.2. 
Vector Single-Width Floating-Point Add/Subtract Instructions -defm : VPatBinaryFPSDNode_VV_VF<any_fadd, "PseudoVFADD">; +defm : VPatBinaryFPSDNode_VV_VF_RM<any_fadd, "PseudoVFADD">; defm : VPatBinaryFPSDNode_VV_VF<any_fsub, "PseudoVFSUB">; defm : VPatBinaryFPSDNode_R_VF<any_fsub, "PseudoVFRSUB">; Index: llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td =================================================================== --- llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td +++ llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td @@ -1168,7 +1168,8 @@ VReg Op1Class, DAGOperand Op2Class, string Constraint, - int DummyMask = 1> : + int DummyMask = 1, + int RVVFixedPoint = 1> : Pseudo<(outs RetClass:$rd), (ins Op1Class:$rs2, Op2Class:$rs1, ixlenimm:$rm, AVL:$vl, ixlenimm:$sew), []>, RISCVVPseudo { @@ -1180,12 +1181,14 @@ let HasSEWOp = 1; let HasDummyMask = DummyMask; let HasRoundModeOp = 1; + let IsRVVFixedPoint = RVVFixedPoint; } class VPseudoBinaryNoMaskTURoundingMode<VReg RetClass, VReg Op1Class, DAGOperand Op2Class, - string Constraint> : + string Constraint, + int RVVFixedPoint> : Pseudo<(outs RetClass:$rd), (ins RetClass:$merge, Op1Class:$rs2, Op2Class:$rs1, ixlenimm:$rm, AVL:$vl, ixlenimm:$sew), []>, @@ -1199,12 +1202,14 @@ let HasDummyMask = 1; let HasMergeOp = 1; let HasRoundModeOp = 1; + let IsRVVFixedPoint = RVVFixedPoint; } class VPseudoBinaryMaskPolicyRoundingMode<VReg RetClass, RegisterClass Op1Class, DAGOperand Op2Class, - string Constraint> : + string Constraint, + int RVVFixedPoint> : Pseudo<(outs GetVRegNoV0<RetClass>.R:$rd), (ins GetVRegNoV0<RetClass>.R:$merge, Op1Class:$rs2, Op2Class:$rs1, @@ -1221,6 +1226,7 @@ let HasVecPolicyOp = 1; let UsesMaskPolicy = 1; let HasRoundModeOp = 1; + let IsRVVFixedPoint = RVVFixedPoint; } // Special version of VPseudoBinaryNoMask where we pretend the first source is @@ -2036,16 +2042,18 @@ VReg Op1Class, DAGOperand Op2Class, LMULInfo MInfo, - string Constraint = ""> { + string Constraint = "", + int IsRVVFixedPoint = 1> { 
let VLMul = MInfo.value in { def "_" # MInfo.MX : - VPseudoBinaryNoMaskRoundingMode<RetClass, Op1Class, Op2Class, Constraint>; + VPseudoBinaryNoMaskRoundingMode<RetClass, Op1Class, Op2Class, Constraint, + /*DummyMask = */ 1, IsRVVFixedPoint>; def "_" # MInfo.MX # "_TU" : VPseudoBinaryNoMaskTURoundingMode<RetClass, Op1Class, Op2Class, - Constraint>; + Constraint, IsRVVFixedPoint>; def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMaskPolicyRoundingMode<RetClass, Op1Class, Op2Class, - Constraint>, + Constraint, IsRVVFixedPoint>, RISCVMaskedPseudo</*MaskOpIdx*/ 3>; } } @@ -2109,6 +2117,11 @@ defm _VV : VPseudoBinary<m.vrclass, m.vrclass, m.vrclass, m, Constraint, sew>; } +multiclass VPseudoBinaryFV_VV_RM<LMULInfo m, string Constraint = ""> { + defm _VV : VPseudoBinaryRoundingMode<m.vrclass, m.vrclass, m.vrclass, m, Constraint, + /*IsRVVFixedPoint=*/ 0>; +} + multiclass VPseudoVGTR_VV_EEW<int eew, string Constraint = ""> { foreach m = MxList in { defvar mx = m.MX; @@ -2157,6 +2170,12 @@ f.fprclass, m, Constraint, sew>; } +multiclass VPseudoBinaryV_VF_RM<LMULInfo m, FPR_Info f, string Constraint = ""> { + defm "_V" # f.FX : VPseudoBinaryRoundingMode<m.vrclass, m.vrclass, + f.fprclass, m, Constraint, + /*IsRVVFixedPoint = */ 0>; +} + multiclass VPseudoVSLD1_VF<string Constraint = ""> { foreach f = FPList in { foreach m = f.MxList in { @@ -2891,6 +2910,28 @@ } } +multiclass VPseudoVALU_VV_VF_RM { + foreach m = MxListF in { + defvar mx = m.MX; + defvar WriteVFALUV_MX = !cast<SchedWrite>("WriteVFALUV_" # mx); + defvar ReadVFALUV_MX = !cast<SchedRead>("ReadVFALUV_" # mx); + + defm "" : VPseudoBinaryFV_VV_RM<m>, + Sched<[WriteVFALUV_MX, ReadVFALUV_MX, ReadVFALUV_MX, ReadVMask]>; + } + + foreach f = FPList in { + foreach m = f.MxList in { + defvar mx = m.MX; + defvar WriteVFALUF_MX = !cast<SchedWrite>("WriteVFALUF_" # mx); + defvar ReadVFALUV_MX = !cast<SchedRead>("ReadVFALUV_" # mx); + defvar ReadVFALUF_MX = !cast<SchedRead>("ReadVFALUF_" # mx); + defm "" : VPseudoBinaryV_VF_RM<m, 
f>, + Sched<[WriteVFALUF_MX, ReadVFALUV_MX, ReadVFALUF_MX, ReadVMask]>; + } + } +} + multiclass VPseudoVALU_VF { foreach f = FPList in { foreach m = f.MxList in { @@ -6008,7 +6049,7 @@ // 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions //===----------------------------------------------------------------------===// let Uses = [FRM], mayRaiseFPException = true in { -defm PseudoVFADD : VPseudoVALU_VV_VF; +defm PseudoVFADD : VPseudoVALU_VV_VF_RM; defm PseudoVFSUB : VPseudoVALU_VV_VF; defm PseudoVFRSUB : VPseudoVALU_VF; } @@ -6681,7 +6722,8 @@ //===----------------------------------------------------------------------===// // 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions //===----------------------------------------------------------------------===// -defm : VPatBinaryV_VV_VX<"int_riscv_vfadd", "PseudoVFADD", AllFloatVectors>; +defm : VPatBinaryV_VV_VXRoundingMode<"int_riscv_vfadd", "PseudoVFADD", + AllFloatVectors>; defm : VPatBinaryV_VV_VX<"int_riscv_vfsub", "PseudoVFSUB", AllFloatVectors>; defm : VPatBinaryV_VX<"int_riscv_vfrsub", "PseudoVFRSUB", AllFloatVectors>; Index: llvm/lib/Target/RISCV/RISCVInstrFormats.td =================================================================== --- llvm/lib/Target/RISCV/RISCVInstrFormats.td +++ llvm/lib/Target/RISCV/RISCVInstrFormats.td @@ -220,6 +220,14 @@ bit HasRoundModeOp = 0; let TSFlags{20} = HasRoundModeOp; + + // This is only valid when HasRoundModeOp is set to 1. HasRoundModeOp is set + // to 1 for vector fixed-point or floating-point intrinsics. This bit is + // processed under pass 'RISCVInsertReadWriteCSR' pass to distinguish between + // fixed-point / floating-point instructions and emit appropriate read/write + // to the correct CSR. 
+ bit IsRVVFixedPoint = 0; + let TSFlags{21} = IsRVVFixedPoint; } // Pseudo instructions Index: llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp =================================================================== --- llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp +++ llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp @@ -13,6 +13,8 @@ // //===----------------------------------------------------------------------===// +#include "MCTargetDesc/RISCVBaseInfo.h" +#include "MCTargetDesc/RISCVMCTargetDesc.h" #include "RISCV.h" #include "RISCVSubtarget.h" #include "llvm/CodeGen/MachineFunctionPass.h" @@ -45,7 +47,7 @@ } private: - bool emitWriteVXRM(MachineBasicBlock &MBB); + bool emitWriteRoundingMode(MachineBasicBlock &MBB); std::optional<unsigned> getRoundModeIdx(const MachineInstr &MI); }; @@ -74,22 +76,38 @@ // This function inserts a write to vxrm when encountering an RVV fixed-point // instruction. -bool RISCVInsertReadWriteCSR::emitWriteVXRM(MachineBasicBlock &MBB) { +bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) { bool Changed = false; for (MachineInstr &MI : MBB) { if (auto RoundModeIdx = getRoundModeIdx(MI)) { Changed = true; - - unsigned VXRMImm = MI.getOperand(*RoundModeIdx).getImm(); - - // The value '4' is a hint to this pass to not alter the vxrm value. - if (VXRMImm == 4) - continue; - - BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteVXRMImm)) - .addImm(VXRMImm); - MI.addOperand(MachineOperand::CreateReg(RISCV::VXRM, /*IsDef*/ false, - /*IsImp*/ true)); + if (RISCVII::isRVVFixedPoint(MI.getDesc().TSFlags)) { + unsigned VXRMImm = MI.getOperand(*RoundModeIdx).getImm(); + + // The value '99' is a hint to this pass to not alter the vxrm value. 
+ if (VXRMImm == 99) + continue; + + BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteVXRMImm)) + .addImm(VXRMImm); + MI.addOperand(MachineOperand::CreateReg(RISCV::VXRM, /*IsDef*/ false, + /*IsImp*/ true)); + // BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRMImm)) + // .addImm(VXRMImm); + // MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false, + // /*IsImp*/ true)); + } else { // FRM + unsigned FRMImm = MI.getOperand(*RoundModeIdx).getImm(); + + // The value '99' is a hint to this pass to not alter the frm value. + if (FRMImm == 99) + continue; + + BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRMImm)) + .addImm(FRMImm); + MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false, + /*IsImp*/ true)); + } } } return Changed; @@ -106,7 +124,7 @@ bool Changed = false; for (MachineBasicBlock &MBB : MF) - Changed |= emitWriteVXRM(MBB); + Changed |= emitWriteRoundingMode(MBB); return Changed; } Index: llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h =================================================================== --- llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h +++ llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h @@ -115,6 +115,9 @@ HasRoundModeOpShift = IsSignExtendingOpWShift + 1, HasRoundModeOpMask = 1 << HasRoundModeOpShift, + + IsRVVFixedPointShift = HasRoundModeOpShift + 1, + IsRVVFixedPointMask = 1 << IsRVVFixedPointShift, }; enum VLMUL : uint8_t { @@ -181,6 +184,11 @@ return TSFlags & HasRoundModeOpMask; } +/// \returns true if this instruction is a RISC-V Vector fixed-point instruction +static inline bool isRVVFixedPoint(uint64_t TSFlags) { + return TSFlags & IsRVVFixedPointMask; +} + static inline unsigned getMergeOpNum(const MCInstrDesc &Desc) { assert(hasMergeOp(Desc.TSFlags)); assert(!Desc.isVariadic()); Index: llvm/include/llvm/IR/IntrinsicsRISCV.td =================================================================== --- llvm/include/llvm/IR/IntrinsicsRISCV.td +++ 
llvm/include/llvm/IR/IntrinsicsRISCV.td @@ -420,6 +420,27 @@ let ScalarOperand = 2; let VLOperand = 4; } + // For destination vector type is the same as first source vector. + // Input: (passthru, vector_in, vector_in/scalar_in, frm, vl) + class RISCVBinaryAAXUnMaskedRoundingMode + : DefaultAttrsIntrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, LLVMMatchType<0>, llvm_any_ty, + llvm_anyint_ty, LLVMMatchType<2>], + [ImmArg<ArgIndex<3>>, IntrNoMem]>, RISCVVIntrinsic { + let ScalarOperand = 2; + let VLOperand = 4; + } + // For destination vector type is the same as first source vector (with mask). + // Input: (maskedoff, vector_in, vector_in/scalar_in, mask, frm, vl, policy) + class RISCVBinaryAAXMaskedRoundingMode + : DefaultAttrsIntrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, LLVMMatchType<0>, llvm_any_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, llvm_anyint_ty, + LLVMMatchType<2>, LLVMMatchType<2>], + [ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<6>>, IntrNoMem]>, RISCVVIntrinsic { + let ScalarOperand = 2; + let VLOperand = 5; + } // For destination vector type is the same as first source vector. The // second source operand must match the destination type or be an XLen scalar. // Input: (passthru, vector_in, vector_in/scalar_in, vl) @@ -1084,6 +1105,10 @@ def "int_riscv_" # NAME : RISCVBinaryAAXUnMasked; def "int_riscv_" # NAME # "_mask" : RISCVBinaryAAXMasked; } + multiclass RISCVBinaryAAXRoundingMode { + def "int_riscv_" # NAME : RISCVBinaryAAXUnMaskedRoundingMode; + def "int_riscv_" # NAME # "_mask" : RISCVBinaryAAXMaskedRoundingMode; + } // Like RISCVBinaryAAX, but the second operand is used a shift amount so it // must be a vector or an XLen scalar. 
multiclass RISCVBinaryAAShift { @@ -1292,7 +1317,7 @@ defm vwmaccus : RISCVTernaryWide; defm vwmaccsu : RISCVTernaryWide; - defm vfadd : RISCVBinaryAAX; + defm vfadd : RISCVBinaryAAXRoundingMode; defm vfsub : RISCVBinaryAAX; defm vfrsub : RISCVBinaryAAX; Index: clang/utils/TableGen/RISCVVEmitter.cpp =================================================================== --- clang/utils/TableGen/RISCVVEmitter.cpp +++ clang/utils/TableGen/RISCVVEmitter.cpp @@ -65,6 +65,7 @@ bool HasMaskedOffOperand :1; bool HasTailPolicy : 1; bool HasMaskPolicy : 1; + bool HasFRMRoundModeOp : 1; bool IsTuple : 1; uint8_t UnMaskedPolicyScheme : 2; uint8_t MaskedPolicyScheme : 2; @@ -512,6 +513,7 @@ StringRef MaskedIRName = R->getValueAsString("MaskedIRName"); unsigned NF = R->getValueAsInt("NF"); bool IsTuple = R->getValueAsBit("IsTuple"); + bool HasFRMRoundModeOp = R->getValueAsBit("HasFRMRoundModeOp"); const Policy DefaultPolicy; SmallVector<Policy> SupportedUnMaskedPolicies = @@ -559,7 +561,7 @@ /*IsMasked=*/false, /*HasMaskedOffOperand=*/false, HasVL, UnMaskedPolicyScheme, SupportOverloading, HasBuiltinAlias, ManualCodegen, *Types, IntrinsicTypes, RequiredFeatures, NF, - DefaultPolicy)); + DefaultPolicy, HasFRMRoundModeOp)); if (UnMaskedPolicyScheme != PolicyScheme::SchemeNone) for (auto P : SupportedUnMaskedPolicies) { SmallVector<PrototypeDescriptor> PolicyPrototype = @@ -574,7 +576,7 @@ /*IsMask=*/false, /*HasMaskedOffOperand=*/false, HasVL, UnMaskedPolicyScheme, SupportOverloading, HasBuiltinAlias, ManualCodegen, *PolicyTypes, IntrinsicTypes, RequiredFeatures, - NF, P)); + NF, P, HasFRMRoundModeOp)); } if (!HasMasked) continue; @@ -585,7 +587,8 @@ Name, SuffixStr, OverloadedName, OverloadedSuffixStr, MaskedIRName, /*IsMasked=*/true, HasMaskedOffOperand, HasVL, MaskedPolicyScheme, SupportOverloading, HasBuiltinAlias, ManualCodegen, *MaskTypes, - IntrinsicTypes, RequiredFeatures, NF, DefaultPolicy)); + IntrinsicTypes, RequiredFeatures, NF, DefaultPolicy, + HasFRMRoundModeOp)); if 
(MaskedPolicyScheme == PolicyScheme::SchemeNone) continue; for (auto P : SupportedMaskedPolicies) { @@ -600,7 +603,7 @@ MaskedIRName, /*IsMasked=*/true, HasMaskedOffOperand, HasVL, MaskedPolicyScheme, SupportOverloading, HasBuiltinAlias, ManualCodegen, *PolicyTypes, IntrinsicTypes, RequiredFeatures, NF, - P)); + P, HasFRMRoundModeOp)); } } // End for Log2LMULList } // End for TypeRange @@ -653,6 +656,7 @@ SR.Suffix = parsePrototypes(SuffixProto); SR.OverloadedSuffix = parsePrototypes(OverloadedSuffixProto); SR.IsTuple = IsTuple; + SR.HasFRMRoundModeOp = HasFRMRoundModeOp; SemaRecords->push_back(SR); } @@ -695,6 +699,7 @@ R.UnMaskedPolicyScheme = SR.UnMaskedPolicyScheme; R.MaskedPolicyScheme = SR.MaskedPolicyScheme; R.IsTuple = SR.IsTuple; + R.HasFRMRoundModeOp = SR.HasFRMRoundModeOp; assert(R.PrototypeIndex != static_cast<uint16_t>(SemaSignatureTable::INVALID_INDEX)); Index: clang/lib/Support/RISCVVIntrinsicUtils.cpp =================================================================== --- clang/lib/Support/RISCVVIntrinsicUtils.cpp +++ clang/lib/Support/RISCVVIntrinsicUtils.cpp @@ -870,20 +870,19 @@ //===----------------------------------------------------------------------===// // RVVIntrinsic implementation //===----------------------------------------------------------------------===// -RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix, - StringRef NewOverloadedName, - StringRef OverloadedSuffix, StringRef IRName, - bool IsMasked, bool HasMaskedOffOperand, bool HasVL, - PolicyScheme Scheme, bool SupportOverloading, - bool HasBuiltinAlias, StringRef ManualCodegen, - const RVVTypes &OutInTypes, - const std::vector<int64_t> &NewIntrinsicTypes, - const std::vector<StringRef> &RequiredFeatures, - unsigned NF, Policy NewPolicyAttrs) +RVVIntrinsic::RVVIntrinsic( + StringRef NewName, StringRef Suffix, StringRef NewOverloadedName, + StringRef OverloadedSuffix, StringRef IRName, bool IsMasked, + bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme, + bool 
SupportOverloading, bool HasBuiltinAlias, StringRef ManualCodegen, + const RVVTypes &OutInTypes, const std::vector<int64_t> &NewIntrinsicTypes, + const std::vector<StringRef> &RequiredFeatures, unsigned NF, + Policy NewPolicyAttrs, bool HasFRMRoundModeOp) : IRName(IRName), IsMasked(IsMasked), HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme), SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias), - ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) { + ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs), + HasFRMRoundModeOp(HasFRMRoundModeOp) { // Init BuiltinName, Name and OverloadedName BuiltinName = NewName.str(); @@ -898,7 +897,7 @@ OverloadedName += "_" + OverloadedSuffix.str(); updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName, - PolicyAttrs); + PolicyAttrs, HasFRMRoundModeOp); // Init OutputType and InputTypes OutputType = OutInTypes[0]; @@ -1023,13 +1022,11 @@ "and mask policy"); } -void RVVIntrinsic::updateNamesAndPolicy(bool IsMasked, bool HasPolicy, - std::string &Name, - std::string &BuiltinName, - std::string &OverloadedName, - Policy &PolicyAttrs) { +void RVVIntrinsic::updateNamesAndPolicy( + bool IsMasked, bool HasPolicy, std::string &Name, std::string &BuiltinName, + std::string &OverloadedName, Policy &PolicyAttrs, bool HasFRMRoundModeOp) { - auto appendPolicySuffix = [&](const std::string &suffix) { + auto appendSuffix = [&](const std::string &suffix) { Name += suffix; BuiltinName += suffix; OverloadedName += suffix; @@ -1042,11 +1039,11 @@ if (IsMasked) { if (PolicyAttrs.isTUMUPolicy()) - appendPolicySuffix("_tumu"); + appendSuffix("_tumu"); else if (PolicyAttrs.isTUMAPolicy()) - appendPolicySuffix("_tum"); + appendSuffix("_tum"); else if (PolicyAttrs.isTAMUPolicy()) - appendPolicySuffix("_mu"); + appendSuffix("_mu"); else if (PolicyAttrs.isTAMAPolicy()) { Name += "_m"; if (HasPolicy) @@ -1057,13 +1054,16 @@ llvm_unreachable("Unhandled 
policy condition"); } else { if (PolicyAttrs.isTUPolicy()) - appendPolicySuffix("_tu"); + appendSuffix("_tu"); else if (PolicyAttrs.isTAPolicy()) { if (HasPolicy) BuiltinName += "_ta"; } else llvm_unreachable("Unhandled policy condition"); } + + if (HasFRMRoundModeOp) + appendSuffix("_rm"); } SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) { @@ -1110,6 +1110,7 @@ OS << (int)Record.HasMaskedOffOperand << ","; OS << (int)Record.HasTailPolicy << ","; OS << (int)Record.HasMaskPolicy << ","; + OS << (int)Record.HasFRMRoundModeOp << ","; OS << (int)Record.IsTuple << ","; OS << (int)Record.UnMaskedPolicyScheme << ","; OS << (int)Record.MaskedPolicyScheme << ","; Index: clang/lib/Sema/SemaRISCVVectorLookup.cpp =================================================================== --- clang/lib/Sema/SemaRISCVVectorLookup.cpp +++ clang/lib/Sema/SemaRISCVVectorLookup.cpp @@ -349,7 +349,8 @@ std::string BuiltinName = "__builtin_rvv_" + std::string(Record.Name); RVVIntrinsic::updateNamesAndPolicy(IsMasked, HasPolicy, Name, BuiltinName, - OverloadedName, PolicyAttrs); + OverloadedName, PolicyAttrs, + Record.HasFRMRoundModeOp); // Put into IntrinsicList. 
size_t Index = IntrinsicList.size(); Index: clang/include/clang/Support/RISCVVIntrinsicUtils.h =================================================================== --- clang/include/clang/Support/RISCVVIntrinsicUtils.h +++ clang/include/clang/Support/RISCVVIntrinsicUtils.h @@ -381,6 +381,7 @@ std::vector<int64_t> IntrinsicTypes; unsigned NF = 1; Policy PolicyAttrs; + bool HasFRMRoundModeOp; public: RVVIntrinsic(llvm::StringRef Name, llvm::StringRef Suffix, @@ -391,7 +392,7 @@ const RVVTypes &Types, const std::vector<int64_t> &IntrinsicTypes, const std::vector<llvm::StringRef> &RequiredFeatures, - unsigned NF, Policy PolicyAttrs); + unsigned NF, Policy PolicyAttrs, bool HasFRMRoundModeOp); ~RVVIntrinsic() = default; RVVTypePtr getOutputType() const { return OutputType; } @@ -461,7 +462,7 @@ static void updateNamesAndPolicy(bool IsMasked, bool HasPolicy, std::string &Name, std::string &BuiltinName, std::string &OverloadedName, - Policy &PolicyAttrs); + Policy &PolicyAttrs, bool HasFRMRoundModeOp); }; // RVVRequire should be sync'ed with target features, but only @@ -520,6 +521,7 @@ bool HasMaskedOffOperand : 1; bool HasTailPolicy : 1; bool HasMaskPolicy : 1; + bool HasFRMRoundModeOp : 1; bool IsTuple : 1; uint8_t UnMaskedPolicyScheme : 2; uint8_t MaskedPolicyScheme : 2; Index: clang/include/clang/Basic/riscv_vector_common.td =================================================================== --- clang/include/clang/Basic/riscv_vector_common.td +++ clang/include/clang/Basic/riscv_vector_common.td @@ -234,6 +234,10 @@ // Set to true if the builtin is associated with tuple types. bit IsTuple = false; + + // Set to true if the builtin has a parameter that models floating-point + // rounding mode control + bit HasFRMRoundModeOp = false; } // This is the code emitted in the header. 
Index: clang/include/clang/Basic/riscv_vector.td =================================================================== --- clang/include/clang/Basic/riscv_vector.td +++ clang/include/clang/Basic/riscv_vector.td @@ -226,6 +226,11 @@ [["vv", "v", "vvv"], ["vf", "v", "vve"]]>; +multiclass RVVFloatingBinBuiltinSetRoundingMode + : RVVOutOp1BuiltinSet<NAME, "xfd", + [["vv", "v", "vvvu"], + ["vf", "v", "vveu"]]>; + multiclass RVVFloatingBinVFBuiltinSet : RVVOutOp1BuiltinSet<NAME, "xfd", [["vf", "v", "vve"]]>; @@ -2206,10 +2211,71 @@ defm vnclipu : RVVUnsignedNShiftBuiltinSetRoundingMode; defm vnclip : RVVSignedNShiftBuiltinSetRoundingMode; } +} // 14. Vector Floating-Point Instructions +let HeaderCode = +[{ +enum __RISCV_FRM { + __RISCV_FRM_RNE = 0, + __RISCV_FRM_RTZ = 1, + __RISCV_FRM_RDN = 2, + __RISCV_FRM_RUP = 3, + __RISCV_FRM_RMM = 4, +}; +}] in def frm_enum : RVVHeader; + +let UnMaskedPolicyScheme = HasPassthruOperand in { // 14.2. Vector Single-Width Floating-Point Add/Subtract Instructions -defm vfadd : RVVFloatingBinBuiltinSet; +let ManualCodegen = [{ + { + // LLVM intrinsic + // Unmasked: (passthru, op0, op1, round_mode, vl) + // Masked: (passthru, vector_in, vector_in/scalar_in, mask, frm, vl, policy) + + SmallVector<llvm::Value*, 7> Operands; + bool HasMaskedOff = !( + (IsMasked && (PolicyAttrs & RVV_VTA) && (PolicyAttrs & RVV_VMA)) || + (!IsMasked && PolicyAttrs & RVV_VTA)); + bool HasRoundModeOp = IsMasked ? + (HasMaskedOff ? Ops.size() == 6 : Ops.size() == 5) : + (HasMaskedOff ? Ops.size() == 5 : Ops.size() == 4); + + unsigned Offset = IsMasked ? + (HasMaskedOff ? 2 : 1) : (HasMaskedOff ? 1 : 0); + + if (!HasMaskedOff) + Operands.push_back(llvm::PoisonValue::get(ResultType)); + else + Operands.push_back(Ops[IsMasked ? 
1 : 0]); + + Operands.push_back(Ops[Offset]); // op0 + Operands.push_back(Ops[Offset + 1]); // op1 + + if (IsMasked) + Operands.push_back(Ops[0]); // mask + + if (HasRoundModeOp) { + Operands.push_back(Ops[Offset + 2]); // frm + Operands.push_back(Ops[Offset + 3]); // vl + } else { + Operands.push_back(ConstantInt::get(Ops[Offset + 2]->getType(), 99)); // frm + Operands.push_back(Ops[Offset + 2]); // vl + } + + if (IsMasked) + Operands.push_back(ConstantInt::get(Ops.back()->getType(), PolicyAttrs)); + + IntrinsicTypes = {ResultType, Ops[Offset + 1]->getType(), Ops.back()->getType()}; + llvm::Function *F = CGM.getIntrinsic(ID, IntrinsicTypes); + return Builder.CreateCall(F, Operands, ""); + } +}] in { + let HasFRMRoundModeOp = true in { + defm vfadd : RVVFloatingBinBuiltinSetRoundingMode; + } + defm vfadd : RVVFloatingBinBuiltinSet; +} defm vfsub : RVVFloatingBinBuiltinSet; defm vfrsub : RVVFloatingBinVFBuiltinSet;
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits