llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Amr Hesham (AmrDeveloper) <details> <summary>Changes</summary> This change adds support for the Ternary op for VectorType Issue https://github.com/llvm/llvm-project/issues/136487 --- Full diff: https://github.com/llvm/llvm-project/pull/142393.diff 7 Files Affected: - (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+36) - (modified) clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp (+30) - (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+18) - (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+16-1) - (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h (+10) - (modified) clang/test/CIR/CodeGen/vector-ext.cpp (+15) - (modified) clang/test/CIR/CodeGen/vector.cpp (+16-1) ``````````diff diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 07851610a2abd..eb02d849b79f6 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -2190,4 +2190,40 @@ def VecShuffleDynamicOp : CIR_Op<"vec.shuffle.dynamic", let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// VecTernaryOp +//===----------------------------------------------------------------------===// + +def VecTernaryOp : CIR_Op<"vec.ternary", + [Pure, AllTypesMatch<["result", "vec1", "vec2"]>]> { + let summary = "The `cond ? a : b` ternary operator for vector types"; + let description = [{ + The `cir.vec.ternary` operation represents the C/C++ ternary operator, + `?:`, for vector types, which does a `select` on individual elements of the + vectors. Unlike a regular `?:` operator, there is no short circuiting. All + three arguments are always evaluated. Because there is no short + circuiting, there are no regions in this operation, unlike cir.ternary. + + The first argument is a vector of integral type. 
The second and third + arguments are vectors of the same type and have the same number of elements + as the first argument. + + The result is a vector of the same type as the second and third arguments. + Each element of the result is `(bool)cond[n] ? vec1[n] : vec2[n]`. + }]; + + let arguments = (ins + IntegerVector:$cond, + CIR_VectorType:$vec1, + CIR_VectorType:$vec2 + ); + + let results = (outs CIR_VectorType:$result); + let assemblyFormat = [{ + `(` $cond `,` $vec1 `,` $vec2 `)` `:` qualified(type($cond)) `,` + qualified(type($vec1)) attr-dict + }]; + let hasVerifier = 1; +} + #endif // CLANG_CIR_DIALECT_IR_CIROPS_TD diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index 8448c164a5e58..5ae727dff1095 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -193,6 +193,36 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> { e->getSourceRange().getBegin()); } + mlir::Value + VisitAbstractConditionalOperator(const AbstractConditionalOperator *e) { + mlir::Location loc = cgf.getLoc(e->getSourceRange()); + Expr *condExpr = e->getCond(); + Expr *lhsExpr = e->getTrueExpr(); + Expr *rhsExpr = e->getFalseExpr(); + + // OpenCL: If the condition is a vector, we can treat this condition like + // the select function.
+ if ((cgf.getLangOpts().OpenCL && condExpr->getType()->isVectorType()) || + condExpr->getType()->isExtVectorType()) { + cgf.getCIRGenModule().errorNYI(loc, + "TernaryOp OpenCL VectorType condition"); + return {}; + } + + if (condExpr->getType()->isVectorType() || + condExpr->getType()->isSveVLSBuiltinType()) { + assert(condExpr->getType()->isVectorType() && "?: op for SVE vector NYI"); + mlir::Value condValue = Visit(condExpr); + mlir::Value lhsValue = Visit(lhsExpr); + mlir::Value rhsValue = Visit(rhsExpr); + return builder.create<cir::VecTernaryOp>(loc, condValue, lhsValue, + rhsValue); + } + + cgf.getCIRGenModule().errorNYI(loc, "TernaryOp for non vector types"); + return {}; + } + mlir::Value VisitMemberExpr(MemberExpr *e); mlir::Value VisitInitListExpr(InitListExpr *e); diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 36f050de9f8bb..1236c455304a9 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -1589,6 +1589,24 @@ LogicalResult cir::VecShuffleDynamicOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// VecTernaryOp +//===----------------------------------------------------------------------===// + +LogicalResult cir::VecTernaryOp::verify() { + // Verify that the condition operand has the same number of elements as the + // other operands. (The automatic verification already checked that all + // operands are vector types and that the second and third operands are the + // same type.) 
+ if (mlir::cast<cir::VectorType>(getCond().getType()).getSize() != + getVec1().getType().getSize()) { + return emitOpError() << ": the number of elements in " + << getCond().getType() << " and " + << getVec1().getType() << " don't match"; + } + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index b07e61638c3b4..e5a26260dc8cc 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1708,7 +1708,8 @@ void ConvertCIRToLLVMPass::runOnOperation() { CIRToLLVMVecExtractOpLowering, CIRToLLVMVecInsertOpLowering, CIRToLLVMVecCmpOpLowering, - CIRToLLVMVecShuffleDynamicOpLowering + CIRToLLVMVecShuffleDynamicOpLowering, + CIRToLLVMVecTernaryOpLowering // clang-format on >(converter, patterns.getContext()); @@ -1916,6 +1917,20 @@ mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite( return mlir::success(); } +mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite( + cir::VecTernaryOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // Convert `cond` into a vector of i1, then use that in a `select` op. 
+ mlir::Value bitVec = rewriter.create<mlir::LLVM::ICmpOp>( + op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(), + rewriter.create<mlir::LLVM::ZeroOp>( + op.getCond().getLoc(), + typeConverter->convertType(op.getCond().getType()))); + rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>( + op, bitVec, adaptor.getVec1(), adaptor.getVec2()); + return mlir::success(); +} + std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() { return std::make_unique<ConvertCIRToLLVMPass>(); } diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index 6b8862db2c8be..ed369ff15a710 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -363,6 +363,16 @@ class CIRToLLVMVecShuffleDynamicOpLowering mlir::ConversionPatternRewriter &) const override; }; +class CIRToLLVMVecTernaryOpLowering + : public mlir::OpConversionPattern<cir::VecTernaryOp> { +public: + using mlir::OpConversionPattern<cir::VecTernaryOp>::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VecTernaryOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + } // namespace direct } // namespace cir diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp index 8a0479fc1d088..53258845c2169 100644 --- a/clang/test/CIR/CodeGen/vector-ext.cpp +++ b/clang/test/CIR/CodeGen/vector-ext.cpp @@ -1091,3 +1091,18 @@ void foo17() { // OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16 // OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16 // OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16> + +void foo20() { + vi4 a; + vi4 b; + vi4 c; + vi4 r = c ? 
a : b; +} + +// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i> + +// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer +// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}} + +// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer +// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}} diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp index 4c50f68a56162..49f142d110a81 100644 --- a/clang/test/CIR/CodeGen/vector.cpp +++ b/clang/test/CIR/CodeGen/vector.cpp @@ -1069,4 +1069,19 @@ void foo17() { // OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16 // OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16 -// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16> \ No newline at end of file +// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16> + +void foo20() { + vi4 a; + vi4 b; + vi4 c; + vi4 r = c ? a : b; +} + +// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i> + +// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer +// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}} + +// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer +// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}} `````````` </details> https://github.com/llvm/llvm-project/pull/142393 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits