https://github.com/mmha created https://github.com/llvm/llvm-project/pull/137184
This patch adds TernaryOp to CIR, plus a pattern that flattens the operator in the FlattenCFG pass. This is the first of (probably) three PRs for TernaryOp; I split the patches up to make reviewing them easier. As such, this PR only adds the CIR operation and its flattening. The next PR will add the CodeGen bits for the C++ conditional operator, and the final one will add the cir-simplify transform for TernaryOp and SelectOp. >From 1eed90e3859c2ad8d703708f89976cad8f0faeec Mon Sep 17 00:00:00 2001 From: Morris Hafner <mhaf...@nvidia.com> Date: Thu, 24 Apr 2025 16:12:37 +0200 Subject: [PATCH] [CIR] Upstream TernaryOp This patch adds TernaryOp to CIR plus a pass that flattens the operator in FlattenCFG. --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 57 +++++++++++++++- clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 42 ++++++++++++ .../lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 60 ++++++++++++++-- clang/test/CIR/IR/ternary.cir | 30 ++++++++ clang/test/CIR/Lowering/ternary.cir | 30 ++++++++ clang/test/CIR/Transforms/ternary.cir | 68 +++++++++++++++++++ 6 files changed, 280 insertions(+), 7 deletions(-) create mode 100644 clang/test/CIR/IR/ternary.cir create mode 100644 clang/test/CIR/Lowering/ternary.cir create mode 100644 clang/test/CIR/Transforms/ternary.cir diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 81b447f31feca..76ad5c3666c1b 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -609,8 +609,8 @@ def ConditionOp : CIR_Op<"condition", [ //===----------------------------------------------------------------------===// def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator, - ParentOneOf<["IfOp", "ScopeOp", "WhileOp", - "ForOp", "DoWhileOp"]>]> { + ParentOneOf<["IfOp", "TernaryOp", "ScopeOp", + "WhileOp", "ForOp", "DoWhileOp"]>]> { let summary = "Represents the default branching behaviour of a region"; let description = [{ The `cir.yield` operation terminates 
regions on different CIR operations, @@ -1246,6 +1246,59 @@ def SelectOp : CIR_Op<"select", [Pure, }]; } +//===----------------------------------------------------------------------===// +// TernaryOp +//===----------------------------------------------------------------------===// + +def TernaryOp : CIR_Op<"ternary", + [DeclareOpInterfaceMethods<RegionBranchOpInterface>, + RecursivelySpeculatable, AutomaticAllocationScope, NoRegionArguments]> { + let summary = "The `cond ? a : b` C/C++ ternary operation"; + let description = [{ + The `cir.ternary` operation represents the C/C++ ternary operator, much like + a `select` operation. The first argument is a `cir.bool` condition to + evaluate, followed by two regions to execute (true or false). This is + different from `cir.if` since each region is expected to be a single block + whose closing `cir.yield` carries at most one value: the operation's result. + + Example: + + ```mlir + // x = cond ? a : b; + + %x = cir.ternary(%cond, true { + ... + cir.yield %a : !s32i + }, false { + ... + cir.yield %b : !s32i + }) : (!cir.bool) -> !s32i + ``` + }]; + let arguments = (ins CIR_BoolType:$cond); + let regions = (region AnyRegion:$trueRegion, + AnyRegion:$falseRegion); + let results = (outs Optional<CIR_AnyType>:$result); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins "mlir::Value":$cond, + "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$trueBuilder, + "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$falseBuilder) + > + ]; + + // All constraints already verified elsewhere. 
+ let hasVerifier = 0; + + let assemblyFormat = [{ + `(` $cond `,` + `true` $trueRegion `,` + `false` $falseRegion + `)` `:` functional-type(operands, results) attr-dict + }]; +} + //===----------------------------------------------------------------------===// // GlobalOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 89daf20c5f478..e80d243cb396f 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -1058,6 +1058,48 @@ LogicalResult cir::BinOp::verify() { return mlir::success(); } +//===----------------------------------------------------------------------===// +// TernaryOp +//===----------------------------------------------------------------------===// + +/// Return the successor regions reachable from `point`. From the parent +/// operation, control may enter either the true or the false region; the +/// condition is not inspected here, so both are always reported. From either +/// region, control returns to the parent operation and the yielded value (if +/// any) becomes the operation's result. +void cir::TernaryOp::getSuccessorRegions( + mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { + // The `true` and the `false` region branch back to the parent operation. + if (!point.isParent()) { + regions.push_back(RegionSuccessor(this->getODSResults(0))); + return; + } + + // The condition is not folded here, so both regions are possible successors.
+ regions.push_back(RegionSuccessor(&getTrueRegion())); + regions.push_back(RegionSuccessor(&getFalseRegion())); +} + +void cir::TernaryOp::build( + OpBuilder &builder, OperationState &result, Value cond, + function_ref<void(OpBuilder &, Location)> trueBuilder, + function_ref<void(OpBuilder &, Location)> falseBuilder) { + result.addOperands(cond); + OpBuilder::InsertionGuard guard(builder); + Region *trueRegion = result.addRegion(); + Block *block = builder.createBlock(trueRegion); + trueBuilder(builder, result.location); + Region *falseRegion = result.addRegion(); + builder.createBlock(falseRegion); + falseBuilder(builder, result.location); + + auto yield = dyn_cast<YieldOp>(block->getTerminator()); + assert((yield && yield.getNumOperands() <= 1) && + "expected zero or one result type"); + if (yield.getNumOperands() == 1) + result.addTypes(TypeRange{yield.getOperandTypes().front()}); +} + //===----------------------------------------------------------------------===// // ShiftOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp index 72ccfa8d4e14e..295fa748b1624 100644 --- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp +++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp @@ -254,10 +254,61 @@ class CIRLoopOpInterfaceFlattening } }; +class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> { +public: + using OpRewritePattern<cir::TernaryOp>::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(cir::TernaryOp op, + mlir::PatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + Block *condBlock = rewriter.getInsertionBlock(); + Block::iterator opPosition = rewriter.getInsertionPoint(); + Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); + llvm::SmallVector<mlir::Location, 2> locs; + // Ternary result is optional, make sure to populate the location only + // 
when relevant. + if (op->getResultTypes().size()) + locs.push_back(loc); + auto *continueBlock = + rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs); + rewriter.create<cir::BrOp>(loc, remainingOpsBlock); + + Region &trueRegion = op.getTrueRegion(); + Block *trueBlock = &trueRegion.front(); + mlir::Operation *trueTerminator = trueRegion.back().getTerminator(); + rewriter.setInsertionPointToEnd(&trueRegion.back()); + auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator); + + rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(), + continueBlock); + rewriter.inlineRegionBefore(trueRegion, continueBlock); + + Block *falseBlock = continueBlock; + Region &falseRegion = op.getFalseRegion(); + + falseBlock = &falseRegion.front(); + mlir::Operation *falseTerminator = falseRegion.back().getTerminator(); + rewriter.setInsertionPointToEnd(&falseRegion.back()); + cir::YieldOp falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator); + rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(), + continueBlock); + rewriter.inlineRegionBefore(falseRegion, continueBlock); + + rewriter.setInsertionPointToEnd(condBlock); + rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock); + + rewriter.replaceOp(op, continueBlock->getArguments()); + + // Ok, we're done! 
+ return mlir::success(); + } +}; + void populateFlattenCFGPatterns(RewritePatternSet &patterns) { - patterns - .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening>( - patterns.getContext()); + patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, + CIRScopeOpFlattening, CIRTernaryOpFlattening>( + patterns.getContext()); } void CIRFlattenCFGPass::runOnOperation() { @@ -269,9 +320,8 @@ void CIRFlattenCFGPass::runOnOperation() { getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) { assert(!cir::MissingFeatures::ifOp()); assert(!cir::MissingFeatures::switchOp()); - assert(!cir::MissingFeatures::ternaryOp()); assert(!cir::MissingFeatures::tryOp()); - if (isa<IfOp, ScopeOp, LoopOpInterface>(op)) + if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op)) ops.push_back(op); }); diff --git a/clang/test/CIR/IR/ternary.cir b/clang/test/CIR/IR/ternary.cir new file mode 100644 index 0000000000000..3827dc77726df --- /dev/null +++ b/clang/test/CIR/IR/ternary.cir @@ -0,0 +1,30 @@ +// RUN: cir-opt %s | cir-opt | FileCheck %s +!u32i = !cir.int<u, 32> + +module { + cir.func @blue(%arg0: !cir.bool) -> !u32i { + %0 = cir.ternary(%arg0, true { + %a = cir.const #cir.int<0> : !u32i + cir.yield %a : !u32i + }, false { + %b = cir.const #cir.int<1> : !u32i + cir.yield %b : !u32i + }) : (!cir.bool) -> !u32i + cir.return %0 : !u32i + } +} + +// CHECK: module { + +// CHECK: cir.func @blue(%arg0: !cir.bool) -> !u32i { +// CHECK: %0 = cir.ternary(%arg0, true { +// CHECK: %1 = cir.const #cir.int<0> : !u32i +// CHECK: cir.yield %1 : !u32i +// CHECK: }, false { +// CHECK: %1 = cir.const #cir.int<1> : !u32i +// CHECK: cir.yield %1 : !u32i +// CHECK: }) : (!cir.bool) -> !u32i +// CHECK: cir.return %0 : !u32i +// CHECK: } + +// CHECK: } diff --git a/clang/test/CIR/Lowering/ternary.cir b/clang/test/CIR/Lowering/ternary.cir new file mode 100644 index 0000000000000..247c6ae3a1e17 --- /dev/null +++ b/clang/test/CIR/Lowering/ternary.cir @@ -0,0 +1,30 @@ +// RUN: 
cir-translate -cir-to-llvmir --disable-cc-lowering -o %t.ll %s +// RUN: FileCheck --input-file=%t.ll -check-prefix=LLVM %s + +!u32i = !cir.int<u, 32> + +module { + cir.func @blue(%arg0: !cir.bool) -> !u32i { + %0 = cir.ternary(%arg0, true { + %a = cir.const #cir.int<0> : !u32i + cir.yield %a : !u32i + }, false { + %b = cir.const #cir.int<1> : !u32i + cir.yield %b : !u32i + }) : (!cir.bool) -> !u32i + cir.return %0 : !u32i + } +} + +// LLVM-LABEL: define i32 {{.*}}@blue( +// LLVM-SAME: i1 [[PRED:%[[:alnum:]]+]]) +// LLVM: br i1 [[PRED]], label %[[B1:[[:alnum:]]+]], label %[[B2:[[:alnum:]]+]] +// LLVM: [[B1]]: +// LLVM: br label %[[M:[[:alnum:]]+]] +// LLVM: [[B2]]: +// LLVM: br label %[[M]] +// LLVM: [[M]]: +// LLVM: [[R:%[[:alnum:]]+]] = phi i32 [ 1, %[[B2]] ], [ 0, %[[B1]] ] +// LLVM: br label %[[B3:[[:alnum:]]+]] +// LLVM: [[B3]]: +// LLVM: ret i32 [[R]] diff --git a/clang/test/CIR/Transforms/ternary.cir b/clang/test/CIR/Transforms/ternary.cir new file mode 100644 index 0000000000000..67ef7f95a6b52 --- /dev/null +++ b/clang/test/CIR/Transforms/ternary.cir @@ -0,0 +1,68 @@ +// RUN: cir-opt %s -cir-flatten-cfg -o - | FileCheck %s + +!s32i = !cir.int<s, 32> + +module { + cir.func @foo(%arg0: !s32i) -> !s32i { + %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64} + %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64} + cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> + %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i + %3 = cir.const #cir.int<0> : !s32i + %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool + %5 = cir.ternary(%4, true { + %7 = cir.const #cir.int<3> : !s32i + cir.yield %7 : !s32i + }, false { + %7 = cir.const #cir.int<5> : !s32i + cir.yield %7 : !s32i + }) : (!cir.bool) -> !s32i + cir.store %5, %1 : !s32i, !cir.ptr<!s32i> + %6 = cir.load %1 : !cir.ptr<!s32i>, !s32i + cir.return %6 : !s32i + } + +// CHECK: cir.func @foo(%arg0: !s32i) -> !s32i { +// CHECK: %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64} 
+// CHECK: %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64} +// CHECK: cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> +// CHECK: %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i +// CHECK: %3 = cir.const #cir.int<0> : !s32i +// CHECK: %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool +// CHECK: cir.brcond %4 ^bb1, ^bb2 +// CHECK: ^bb1: // pred: ^bb0 +// CHECK: %5 = cir.const #cir.int<3> : !s32i +// CHECK: cir.br ^bb3(%5 : !s32i) +// CHECK: ^bb2: // pred: ^bb0 +// CHECK: %6 = cir.const #cir.int<5> : !s32i +// CHECK: cir.br ^bb3(%6 : !s32i) +// CHECK: ^bb3(%7: !s32i): // 2 preds: ^bb1, ^bb2 +// CHECK: cir.br ^bb4 +// CHECK: ^bb4: // pred: ^bb3 +// CHECK: cir.store %7, %1 : !s32i, !cir.ptr<!s32i> +// CHECK: %8 = cir.load %1 : !cir.ptr<!s32i>, !s32i +// CHECK: cir.return %8 : !s32i +// CHECK: } + + cir.func @foo2(%arg0: !cir.bool) { + cir.ternary(%arg0, true { + cir.yield + }, false { + cir.yield + }) : (!cir.bool) -> () + cir.return + } + +// CHECK: cir.func @foo2(%arg0: !cir.bool) { +// CHECK: cir.brcond %arg0 ^bb1, ^bb2 +// CHECK: ^bb1: // pred: ^bb0 +// CHECK: cir.br ^bb3 +// CHECK: ^bb2: // pred: ^bb0 +// CHECK: cir.br ^bb3 +// CHECK: ^bb3: // 2 preds: ^bb1, ^bb2 +// CHECK: cir.br ^bb4 +// CHECK: ^bb4: // pred: ^bb3 +// CHECK: cir.return +// CHECK: } + +} _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits