kushanam updated this revision to Diff 528526. kushanam added a comment. Adding min and max for bf16 and refactoring the code.
Repository: rG LLVM Github Monorepo CHANGES SINCE LAST ACTION https://reviews.llvm.org/D144911/new/ https://reviews.llvm.org/D144911 Files: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp =================================================================== --- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -150,23 +150,11 @@ } static bool Isv2f16Orv2bf16Type(MVT VT) { - switch (VT.SimpleTy) { - default: - return false; - case MVT::v2f16: - case MVT::v2bf16: - return true; - } + return (VT.SimpleTy == MVT::v2f16 || VT.SimpleTy == MVT::v2bf16); } static bool Isf16Orbf16Type(MVT VT) { - switch (VT.SimpleTy) { - default: - return false; - case MVT::f16: - case MVT::bf16: - return true; - } + return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16); } /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive @@ -624,9 +612,6 @@ for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) { setFP16OperationAction(Op, MVT::f16, Legal, Promote); setFP16OperationAction(Op, MVT::v2f16, Legal, Expand); - } - - for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) { setBF16OperationAction(Op, MVT::bf16, Legal, Promote); setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand); } @@ -693,20 +678,18 @@ }; for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) { setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Promote), Promote); + setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote); setOperationAction(Op, MVT::f32, Legal); setOperationAction(Op, MVT::f64, Legal); setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand); + setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand); } for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) { setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Expand), Expand); + setFP16OperationAction(Op, MVT::bf16, 
GetMinMaxAction(Expand), Expand); setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand)); setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand); - } - for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) { - setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote); - setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand); - setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Expand), Expand); - setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand); + setFP16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand); } // No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate. Index: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp =================================================================== --- llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -1294,14 +1294,11 @@ NumElts = EltVT.getVectorNumElements(); EltVT = EltVT.getVectorElementType(); // vectors of f16 are loaded/stored as multiples of v2f16 elements. - if (EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) { - assert(NumElts % 2 == 0 && "Vector must have even number of elements"); - EltVT = MVT::v2f16; - NumElts /= 2; - } else if (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) { - assert(NumElts % 2 == 0 && "Vector must have even number of elements"); - EltVT = MVT::v2bf16; - NumElts /= 2; + if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) || + (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16)) { + assert(NumElts % 2 == 0 && "Vector must have even number of elements"); + EltVT = N->getValueType(0); + NumElts /= 2; } }
Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp =================================================================== --- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -150,23 +150,11 @@ } static bool Isv2f16Orv2bf16Type(MVT VT) { - switch (VT.SimpleTy) { - default: - return false; - case MVT::v2f16: - case MVT::v2bf16: - return true; - } + return (VT.SimpleTy == MVT::v2f16 || VT.SimpleTy == MVT::v2bf16); } static bool Isf16Orbf16Type(MVT VT) { - switch (VT.SimpleTy) { - default: - return false; - case MVT::f16: - case MVT::bf16: - return true; - } + return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16); } /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive @@ -624,9 +612,6 @@ for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) { setFP16OperationAction(Op, MVT::f16, Legal, Promote); setFP16OperationAction(Op, MVT::v2f16, Legal, Expand); - } - - for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) { setBF16OperationAction(Op, MVT::bf16, Legal, Promote); setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand); } @@ -693,20 +678,18 @@ }; for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) { setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Promote), Promote); + setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote); setOperationAction(Op, MVT::f32, Legal); setOperationAction(Op, MVT::f64, Legal); setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand); + setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand); } for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) { setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Expand), Expand); + setFP16OperationAction(Op, MVT::bf16, GetMinMaxAction(Expand), Expand); setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand)); setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand); - } - for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) { - 
setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote); - setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand); - setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Expand), Expand); - setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand); + setFP16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand); } // No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate. Index: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp =================================================================== --- llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -1294,14 +1294,11 @@ NumElts = EltVT.getVectorNumElements(); EltVT = EltVT.getVectorElementType(); // vectors of f16 are loaded/stored as multiples of v2f16 elements. - if (EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) { - assert(NumElts % 2 == 0 && "Vector must have even number of elements"); - EltVT = MVT::v2f16; - NumElts /= 2; - } else if (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) { - assert(NumElts % 2 == 0 && "Vector must have even number of elements"); - EltVT = MVT::v2bf16; - NumElts /= 2; + if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) || + (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16)) { + assert(NumElts % 2 == 0 && "Vector must have even number of elements"); + EltVT = N->getValueType(0); + NumElts /= 2; } }
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits