https://github.com/AmrDeveloper created https://github.com/llvm/llvm-project/pull/142946
This change adds a folder for the VecTernaryOp. Issue: https://github.com/llvm/llvm-project/issues/136487 >From ac8277b48d0affa78f5e5e943e0179c27dd033ec Mon Sep 17 00:00:00 2001 From: AmrDeveloper <am...@programmer.net> Date: Thu, 5 Jun 2025 13:08:57 +0200 Subject: [PATCH] [CIR] Implement folder for VecTernaryOp --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 2 ++ clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 35 +++++++++++++++++++ .../Dialect/Transforms/CIRCanonicalize.cpp | 6 ++-- .../CIR/Transforms/vector-ternary-fold.cir | 20 +++++++++++ 4 files changed, 60 insertions(+), 3 deletions(-) create mode 100644 clang/test/CIR/Transforms/vector-ternary-fold.cir diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 00878f7dd8ed7..eb439f7aa1527 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -2228,7 +2228,9 @@ def VecTernaryOp : CIR_Op<"vec.ternary", `(` $cond `,` $lhs`,` $rhs `)` `:` qualified(type($cond)) `,` qualified(type($lhs)) attr-dict }]; + let hasVerifier = 1; + let hasFolder = 1; } #endif // CLANG_CIR_DIALECT_IR_CIROPS_TD diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index fa7fb592a3cd6..f585254d3340b 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -1638,6 +1638,41 @@ LogicalResult cir::VecTernaryOp::verify() { return success(); } +OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) { + mlir::Attribute cond = adaptor.getCond(); + mlir::Attribute lhs = adaptor.getLhs(); + mlir::Attribute rhs = adaptor.getRhs(); + + if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) && + mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) && + mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs)) { + auto condVec = mlir::cast<cir::ConstVectorAttr>(cond); + auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs); + auto rhsVec = 
mlir::cast<cir::ConstVectorAttr>(rhs); + + mlir::ArrayAttr condElts = condVec.getElts(); + + SmallVector<mlir::Attribute, 16> elements; + elements.reserve(condElts.size()); + + for (const auto &[idx, condAttr] : + llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) { + if (condAttr.getSInt()) { + elements.push_back(lhsVec.getElts()[idx]); + continue; + } + + elements.push_back(rhsVec.getElts()[idx]); + } + + cir::VectorType vecTy = getLhs().getType(); + return cir::ConstVectorAttr::get( + vecTy, mlir::ArrayAttr::get(getContext(), elements)); + } + + return {}; +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp index 7d03e374c27e8..aa3e97033cdda 100644 --- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp +++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp @@ -138,10 +138,10 @@ void CIRCanonicalizePass::runOnOperation() { assert(!cir::MissingFeatures::complexRealOp()); assert(!cir::MissingFeatures::complexImagOp()); assert(!cir::MissingFeatures::callOp()); - // CastOp, UnaryOp, VecExtractOp and VecShuffleDynamicOp are here to perform - // a manual `fold` in applyOpPatternsGreedily. + // CastOp, UnaryOp, VecExtractOp, VecShuffleDynamicOp and VecTernaryOp are + // here to perform a manual `fold` in applyOpPatternsGreedily. 
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp, - VecExtractOp, VecShuffleDynamicOp>(op)) + VecExtractOp, VecShuffleDynamicOp, VecTernaryOp>(op)) ops.push_back(op); }); diff --git a/clang/test/CIR/Transforms/vector-ternary-fold.cir b/clang/test/CIR/Transforms/vector-ternary-fold.cir new file mode 100644 index 0000000000000..f2e18576da74b --- /dev/null +++ b/clang/test/CIR/Transforms/vector-ternary-fold.cir @@ -0,0 +1,20 @@ +// RUN: cir-opt %s -cir-canonicalize -o - | FileCheck %s + +!s32i = !cir.int<s, 32> + +module { + cir.func @vector_ternary_fold_test() -> !cir.vector<4 x !s32i> { + %cond = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i> + %lhs = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i> + %rhs = cir.const #cir.const_vector<[#cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i> + %res = cir.vec.ternary(%cond, %lhs, %rhs) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i> + cir.return %res : !cir.vector<4 x !s32i> + } + + // [1, 0, 1, 0] ? [1, 2, 3, 4] : [5, 6, 7, 8] Will be fold to [1, 6, 3, 8] + // CHECK: cir.func @vector_ternary_fold_test() -> !cir.vector<4 x !s32i> { + // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<6> : !s32i, #cir.int<3> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i> + // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i> +} + + _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits