https://github.com/mmha created https://github.com/llvm/llvm-project/pull/137184

This patch adds TernaryOp to CIR, plus a rewrite pattern in the FlattenCFG
pass that flattens the operation.

This is the first of (probably) three PRs for TernaryOp. I split the patches
up to make reviewing them easier. As such, this PR only adds the CIR
operation. The next PR will add the CodeGen bits for the C++ conditional
operator, and the final one will add the cir-simplify transform for TernaryOp
and SelectOp.

>From 1eed90e3859c2ad8d703708f89976cad8f0faeec Mon Sep 17 00:00:00 2001
From: Morris Hafner <mhaf...@nvidia.com>
Date: Thu, 24 Apr 2025 16:12:37 +0200
Subject: [PATCH] [CIR] Upstream TernaryOp

This patch adds TernaryOp to CIR, plus a rewrite pattern in the FlattenCFG
pass that flattens the operation.
---
 clang/include/clang/CIR/Dialect/IR/CIROps.td  | 57 +++++++++++++++-
 clang/lib/CIR/Dialect/IR/CIRDialect.cpp       | 42 ++++++++++++
 .../lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 60 ++++++++++++++--
 clang/test/CIR/IR/ternary.cir                 | 30 ++++++++
 clang/test/CIR/Lowering/ternary.cir           | 30 ++++++++
 clang/test/CIR/Transforms/ternary.cir         | 68 +++++++++++++++++++
 6 files changed, 280 insertions(+), 7 deletions(-)
 create mode 100644 clang/test/CIR/IR/ternary.cir
 create mode 100644 clang/test/CIR/Lowering/ternary.cir
 create mode 100644 clang/test/CIR/Transforms/ternary.cir

diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 81b447f31feca..76ad5c3666c1b 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -609,8 +609,8 @@ def ConditionOp : CIR_Op<"condition", [
 
//===----------------------------------------------------------------------===//
 
 def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
-                               ParentOneOf<["IfOp", "ScopeOp", "WhileOp",
-                                            "ForOp", "DoWhileOp"]>]> {
+                               ParentOneOf<["IfOp", "TernaryOp", "ScopeOp",
+                                            "WhileOp", "ForOp", 
"DoWhileOp"]>]> {
   let summary = "Represents the default branching behaviour of a region";
   let description = [{
     The `cir.yield` operation terminates regions on different CIR operations,
@@ -1246,6 +1246,59 @@ def SelectOp : CIR_Op<"select", [Pure,
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TernaryOp
+//===----------------------------------------------------------------------===//
+
+def TernaryOp : CIR_Op<"ternary",
+      [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+       RecursivelySpeculatable, AutomaticAllocationScope, NoRegionArguments]> {
+  let summary = "The `cond ? a : b` C/C++ ternary operation";
+  let description = [{
+    The `cir.ternary` operation represents C/C++ ternary, much like a `select`
+    operation. The first argument is a `cir.bool` condition to evaluate, followed
+    by two regions to execute (true or false). This is different from `cir.if`
+    since each region is one block sized and the `cir.yield` closing the block
+    scope should have at most one argument (the optional result of the op).
+
+    Example:
+
+    ```mlir
+    // x = cond ? a : b;
+
+    %x = cir.ternary(%cond, true {
+      ...
+      cir.yield %a : i32
+    }, false {
+      ...
+      cir.yield %b : i32
+    }) : (!cir.bool) -> i32
+    ```
+  }];
+  let arguments = (ins CIR_BoolType:$cond);
+  let regions = (region AnyRegion:$trueRegion,
+                        AnyRegion:$falseRegion);
+  let results = (outs Optional<CIR_AnyType>:$result);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "mlir::Value":$cond,
+      "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$trueBuilder,
+      "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$falseBuilder)
+      >
+  ];
+
+  // All constraints already verified elsewhere.
+  let hasVerifier = 0;
+
+  let assemblyFormat = [{
+    `(` $cond `,`
+      `true` $trueRegion `,`
+      `false` $falseRegion
+    `)` `:` functional-type(operands, results) attr-dict
+  }];
+}
+
 
//===----------------------------------------------------------------------===//
 // GlobalOp
 
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 89daf20c5f478..e80d243cb396f 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1058,6 +1058,48 @@ LogicalResult cir::BinOp::verify() {
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TernaryOp
+//===----------------------------------------------------------------------===//
+
+/// Report the regions control may flow to from `point`. When `point` is the
+/// ternary operation itself, the true and the false region are both listed;
+/// the condition is not inspected here. When `point` is one of the two
+/// regions, control branches back to the parent operation and the successor
+/// is built from the operation's results.
+void cir::TernaryOp::getSuccessorRegions(
+    mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  // Entering from the parent operation: either region may be executed.
+  if (point.isParent()) {
+    regions.push_back(RegionSuccessor(&getTrueRegion()));
+    regions.push_back(RegionSuccessor(&getFalseRegion()));
+    return;
+  }
+
+  // The `true` and the `false` region branch back to the parent operation.
+  regions.push_back(RegionSuccessor(this->getODSResults(0)));
+}
+
+/// Build a `cir.ternary` whose condition is `cond`. `trueBuilder` and
+/// `falseBuilder` are invoked, in that order, with the builder positioned in
+/// the block created for the true and the false region respectively. The
+/// result type is taken from the `cir.yield` terminating the block created
+/// for the true region: no result when it yields nothing, otherwise the type
+/// of its single operand.
+void cir::TernaryOp::build(
+    OpBuilder &builder, OperationState &result, Value cond,
+    function_ref<void(OpBuilder &, Location)> trueBuilder,
+    function_ref<void(OpBuilder &, Location)> falseBuilder) {
+  result.addOperands(cond);
+
+  // Restore the caller's insertion point once both regions have been built.
+  OpBuilder::InsertionGuard guard(builder);
+
+  Region *trueRegion = result.addRegion();
+  Block *block = builder.createBlock(trueRegion);
+  trueBuilder(builder, result.location);
+
+  Region *falseRegion = result.addRegion();
+  builder.createBlock(falseRegion);
+  falseBuilder(builder, result.location);
+
+  // The terminator is used unconditionally below, so use cast<> (which
+  // asserts on a type mismatch) instead of dyn_cast<> guarded only by an
+  // assert that disappears in release builds.
+  auto yield = cast<YieldOp>(block->getTerminator());
+  assert(yield.getNumOperands() <= 1 && "expected zero or one result type");
+  if (yield.getNumOperands() == 1)
+    result.addTypes(TypeRange{yield.getOperandTypes().front()});
+}
+
 
//===----------------------------------------------------------------------===//
 // ShiftOp
 
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
index 72ccfa8d4e14e..295fa748b1624 100644
--- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
@@ -254,10 +254,61 @@ class CIRLoopOpInterfaceFlattening
   }
 };
 
+/// Flattens `cir.ternary` into explicit control flow. The condition feeds a
+/// `cir.brcond` that targets the inlined true and false regions; each region's
+/// closing `cir.yield` becomes a `cir.br` to a continuation block, whose block
+/// arguments (one per ternary result) replace the results of the operation.
+class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
+public:
+  using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
+
+  /// Rewrites `op` as described on the class. Always returns success.
+  mlir::LogicalResult
+  matchAndRewrite(cir::TernaryOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+
+    // Split the enclosing block at the rewriter's insertion point; everything
+    // from that point on moves to `remainingOpsBlock`.
+    // NOTE(review): assumes the rewriter is positioned at `op` on entry --
+    // confirm against the pattern driver.
+    Block *condBlock = rewriter.getInsertionBlock();
+    Block::iterator opPosition = rewriter.getInsertionPoint();
+    Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
+
+    // Ternary result is optional, make sure to populate the location only
+    // when relevant.
+    llvm::SmallVector<mlir::Location, 2> locs;
+    if (op->getResultTypes().size())
+      locs.push_back(loc);
+
+    // The continuation block takes the yielded value (if any) as a block
+    // argument and then branches to the remaining operations.
+    Block *continueBlock =
+        rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
+    rewriter.create<cir::BrOp>(loc, remainingOpsBlock);
+
+    // True branch: turn the closing yield into a branch to the continuation
+    // block, then move the region's blocks in front of that block.
+    Region &trueRegion = op.getTrueRegion();
+    Block *trueBlock = &trueRegion.front();
+    mlir::Operation *trueTerminator = trueRegion.back().getTerminator();
+    rewriter.setInsertionPointToEnd(&trueRegion.back());
+    // cast<> rather than dyn_cast<>: the result is used unconditionally, so a
+    // non-yield terminator should trip the cast assertion instead of being
+    // dereferenced as a null op.
+    auto trueYieldOp = cast<cir::YieldOp>(trueTerminator);
+    rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
+                                           continueBlock);
+    rewriter.inlineRegionBefore(trueRegion, continueBlock);
+
+    // False branch: same treatment as the true branch.
+    Region &falseRegion = op.getFalseRegion();
+    Block *falseBlock = &falseRegion.front();
+    mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
+    rewriter.setInsertionPointToEnd(&falseRegion.back());
+    auto falseYieldOp = cast<cir::YieldOp>(falseTerminator);
+    rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(),
+                                           continueBlock);
+    rewriter.inlineRegionBefore(falseRegion, continueBlock);
+
+    // Dispatch on the condition from the end of the original block.
+    rewriter.setInsertionPointToEnd(condBlock);
+    rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock);
+
+    // The continuation block's arguments stand in for the ternary's results.
+    rewriter.replaceOp(op, continueBlock->getArguments());
+    return mlir::success();
+  }
+};
+
 void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
-  patterns
-      .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, 
CIRScopeOpFlattening>(
-          patterns.getContext());
+  patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening,
+               CIRScopeOpFlattening, CIRTernaryOpFlattening>(
+      patterns.getContext());
 }
 
 void CIRFlattenCFGPass::runOnOperation() {
@@ -269,9 +320,8 @@ void CIRFlattenCFGPass::runOnOperation() {
   getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
     assert(!cir::MissingFeatures::ifOp());
     assert(!cir::MissingFeatures::switchOp());
-    assert(!cir::MissingFeatures::ternaryOp());
     assert(!cir::MissingFeatures::tryOp());
-    if (isa<IfOp, ScopeOp, LoopOpInterface>(op))
+    if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op))
       ops.push_back(op);
   });
 
diff --git a/clang/test/CIR/IR/ternary.cir b/clang/test/CIR/IR/ternary.cir
new file mode 100644
index 0000000000000..3827dc77726df
--- /dev/null
+++ b/clang/test/CIR/IR/ternary.cir
@@ -0,0 +1,30 @@
+// RUN: cir-opt %s | cir-opt | FileCheck %s
+!u32i = !cir.int<u, 32>
+
+module  {
+  cir.func @blue(%arg0: !cir.bool) -> !u32i {
+    %0 = cir.ternary(%arg0, true {
+      %a = cir.const #cir.int<0> : !u32i
+      cir.yield %a : !u32i
+    }, false {
+      %b = cir.const #cir.int<1> : !u32i
+      cir.yield %b : !u32i
+    }) : (!cir.bool) -> !u32i
+    cir.return %0 : !u32i
+  }
+}
+
+// CHECK: module  {
+
+// CHECK: cir.func @blue(%arg0: !cir.bool) -> !u32i {
+// CHECK:   %0 = cir.ternary(%arg0, true {
+// CHECK:     %1 = cir.const #cir.int<0> : !u32i
+// CHECK:     cir.yield %1 : !u32i
+// CHECK:   }, false {
+// CHECK:     %1 = cir.const #cir.int<1> : !u32i
+// CHECK:     cir.yield %1 : !u32i
+// CHECK:   }) : (!cir.bool) -> !u32i
+// CHECK:   cir.return %0 : !u32i
+// CHECK: }
+
+// CHECK: }
diff --git a/clang/test/CIR/Lowering/ternary.cir b/clang/test/CIR/Lowering/ternary.cir
new file mode 100644
index 0000000000000..247c6ae3a1e17
--- /dev/null
+++ b/clang/test/CIR/Lowering/ternary.cir
@@ -0,0 +1,30 @@
+// RUN: cir-translate -cir-to-llvmir --disable-cc-lowering -o %t.ll %s
+// RUN: FileCheck --input-file=%t.ll -check-prefix=LLVM %s
+
+!u32i = !cir.int<u, 32>
+
+module  {
+  cir.func @blue(%arg0: !cir.bool) -> !u32i {
+    %0 = cir.ternary(%arg0, true {
+      %a = cir.const #cir.int<0> : !u32i
+      cir.yield %a : !u32i
+    }, false {
+      %b = cir.const #cir.int<1> : !u32i
+      cir.yield %b : !u32i
+    }) : (!cir.bool) -> !u32i
+    cir.return %0 : !u32i
+  }
+}
+
+// LLVM-LABEL: define i32 {{.*}}@blue(
+// LLVM-SAME: i1 [[PRED:%[[:alnum:]]+]])
+// LLVM:   br i1 [[PRED]], label %[[B1:[[:alnum:]]+]], label %[[B2:[[:alnum:]]+]]
+// LLVM: [[B1]]:
+// LLVM:   br label %[[M:[[:alnum:]]+]]
+// LLVM: [[B2]]:
+// LLVM:   br label %[[M]]
+// LLVM: [[M]]:
+// LLVM:   [[R:%[[:alnum:]]+]] = phi i32 [ 1, %[[B2]] ], [ 0, %[[B1]] ]
+// LLVM:   br label %[[B3:[[:alnum:]]+]]
+// LLVM: [[B3]]:
+// LLVM:   ret i32 [[R]]
diff --git a/clang/test/CIR/Transforms/ternary.cir b/clang/test/CIR/Transforms/ternary.cir
new file mode 100644
index 0000000000000..67ef7f95a6b52
--- /dev/null
+++ b/clang/test/CIR/Transforms/ternary.cir
@@ -0,0 +1,68 @@
+// RUN: cir-opt %s -cir-flatten-cfg -o - | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+  cir.func @foo(%arg0: !s32i) -> !s32i {
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
+    %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+    %3 = cir.const #cir.int<0> : !s32i
+    %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
+    %5 = cir.ternary(%4, true {
+      %7 = cir.const #cir.int<3> : !s32i
+      cir.yield %7 : !s32i
+    }, false {
+      %7 = cir.const #cir.int<5> : !s32i
+      cir.yield %7 : !s32i
+    }) : (!cir.bool) -> !s32i
+    cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
+    %6 = cir.load %1 : !cir.ptr<!s32i>, !s32i
+    cir.return %6 : !s32i
+  }
+
+// CHECK: cir.func @foo(%arg0: !s32i) -> !s32i {
+// CHECK:   %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
+// CHECK:   %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
+// CHECK:   cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+// CHECK:   %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+// CHECK:   %3 = cir.const #cir.int<0> : !s32i
+// CHECK:   %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
+// CHECK:    cir.brcond %4 ^bb1, ^bb2
+// CHECK:  ^bb1:  // pred: ^bb0
+// CHECK:    %5 = cir.const #cir.int<3> : !s32i
+// CHECK:    cir.br ^bb3(%5 : !s32i)
+// CHECK:  ^bb2:  // pred: ^bb0
+// CHECK:    %6 = cir.const #cir.int<5> : !s32i
+// CHECK:    cir.br ^bb3(%6 : !s32i)
+// CHECK:  ^bb3(%7: !s32i):  // 2 preds: ^bb1, ^bb2
+// CHECK:    cir.br ^bb4
+// CHECK:  ^bb4:  // pred: ^bb3
+// CHECK:    cir.store %7, %1 : !s32i, !cir.ptr<!s32i>
+// CHECK:    %8 = cir.load %1 : !cir.ptr<!s32i>, !s32i
+// CHECK:    cir.return %8 : !s32i
+// CHECK:  }
+
+  cir.func @foo2(%arg0: !cir.bool) {
+    cir.ternary(%arg0, true {
+      cir.yield
+    }, false {
+      cir.yield
+    }) : (!cir.bool) -> ()
+    cir.return
+  }
+
+// CHECK: cir.func @foo2(%arg0: !cir.bool) {
+// CHECK:   cir.brcond %arg0 ^bb1, ^bb2
+// CHECK: ^bb1:  // pred: ^bb0
+// CHECK:   cir.br ^bb3
+// CHECK: ^bb2:  // pred: ^bb0
+// CHECK:   cir.br ^bb3
+// CHECK: ^bb3:  // 2 preds: ^bb1, ^bb2
+// CHECK:   cir.br ^bb4
+// CHECK: ^bb4:  // pred: ^bb3
+// CHECK:   cir.return
+// CHECK: }
+
+}

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to