https://github.com/xlauko created https://github.com/llvm/llvm-project/pull/171083

This required fixing memory and conversion bugs: conversion patterns now take
effect immediately, so the original CIR operations are no longer available
once they have been replaced.

>From b93a2e68fdc1a42d91e91eb543106054087aa8f1 Mon Sep 17 00:00:00 2001
From: xlauko <[email protected]>
Date: Mon, 8 Dec 2025 07:32:02 +0100
Subject: [PATCH] [CIR] Make CIR-to-LLVM a one shot conversion

This required fixing memory and conversion bugs: conversion patterns now take
effect immediately, so the original CIR operations are no longer available
once they have been replaced.
---
 clang/include/clang/CIR/Dialect/IR/CIROps.td  |   2 +-
 .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 127 ++++++++++++------
 2 files changed, 86 insertions(+), 43 deletions(-)

diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 3d6de2a97d650..1b7fa953e8a37 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2160,7 +2160,7 @@ def CIR_GlobalOp : CIR_Op<"global", [
       cir::GlobalOp op, mlir::Attribute init,
       mlir::ConversionPatternRewriter &rewriter) const;
 
-    void setupRegionInitializedLLVMGlobalOp(
+    mlir::LLVM::GlobalOp setupRegionInitializedLLVMGlobalOp(
         cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const;
 
     mutable mlir::LLVM::ComdatOp comdatOp = nullptr;
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index aea9e26341f8f..8e2b47dbeb629 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -716,8 +716,7 @@ mlir::LogicalResult CIRToLLVMIsFPClassOpLowering::matchAndRewrite(
 mlir::LogicalResult CIRToLLVMAssumeOpLowering::matchAndRewrite(
     cir::AssumeOp op, OpAdaptor adaptor,
     mlir::ConversionPatternRewriter &rewriter) const {
-  auto cond = adaptor.getPredicate();
-  rewriter.replaceOpWithNewOp<mlir::LLVM::AssumeOp>(op, cond);
+  rewriter.replaceOpWithNewOp<mlir::LLVM::AssumeOp>(op, adaptor.getPredicate());
   return mlir::success();
 }
 
@@ -1130,11 +1129,11 @@ mlir::LogicalResult CIRToLLVMBrCondOpLowering::matchAndRewrite(
   // ZExtOp and if so, delete it if it has a single use.
   assert(!cir::MissingFeatures::zextOp());
 
-  mlir::Value i1Condition = adaptor.getCond();
-
+  
   rewriter.replaceOpWithNewOp<mlir::LLVM::CondBrOp>(
-      brOp, i1Condition, brOp.getDestTrue(), adaptor.getDestOperandsTrue(),
-      brOp.getDestFalse(), adaptor.getDestOperandsFalse());
+      brOp, adaptor.getCond(), brOp.getDestTrue(),
+      adaptor.getDestOperandsTrue(), brOp.getDestFalse(),
+      adaptor.getDestOperandsFalse());
 
   return mlir::success();
 }
@@ -1942,12 +1941,12 @@ mlir::LogicalResult CIRToLLVMFuncOpLowering::matchAndRewriteAlias(
   lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes);
 
   mlir::Location loc = op.getLoc();
+  mlir::OpBuilder builder(op.getContext());
   auto aliasOp = rewriter.replaceOpWithNewOp<mlir::LLVM::AliasOp>(
       op, ty, convertLinkage(op.getLinkage()), op.getName(), op.getDsoLocal(),
       /*threadLocal=*/false, attributes);
 
   // Create the alias body
-  mlir::OpBuilder builder(op.getContext());
   mlir::Block *block = builder.createBlock(&aliasOp.getInitializerRegion());
   builder.setInsertionPointToStart(block);
   // The type of AddressOfOp is always a pointer.
@@ -2053,7 +2052,8 @@ mlir::LogicalResult CIRToLLVMGetGlobalOpLowering::matchAndRewrite(
 
 /// Replace CIR global with a region initialized LLVM global and update
 /// insertion point to the end of the initializer block.
-void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp(
+mlir::LLVM::GlobalOp
+CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp(
     cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const {
   const mlir::Type llvmType =
       convertTypeForMemory(*getTypeConverter(), dataLayout, op.getSymType());
@@ -2080,6 +2080,7 @@ void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp(
           isDsoLocal, isThreadLocal, comdatAttr, attributes);
   newGlobalOp.getRegion().emplaceBlock();
   rewriter.setInsertionPointToEnd(newGlobalOp.getInitializerBlock());
+  return newGlobalOp;
 }
 
 mlir::LogicalResult
@@ -2097,8 +2098,9 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
   // should be updated. For now, we use a custom op to initialize globals
   // to the appropriate value.
   const mlir::Location loc = op.getLoc();
-  setupRegionInitializedLLVMGlobalOp(op, rewriter);
-  CIRAttrToValue valueConverter(op, rewriter, typeConverter);
+  mlir::LLVM::GlobalOp newGlobalOp =
+      setupRegionInitializedLLVMGlobalOp(op, rewriter);
+  CIRAttrToValue valueConverter(newGlobalOp, rewriter, typeConverter);
   mlir::Value value = valueConverter.visit(init);
   mlir::LLVM::ReturnOp::create(rewriter, loc, value);
   return mlir::success();
@@ -2795,42 +2797,45 @@ mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
 mlir::LogicalResult CIRToLLVMSelectOpLowering::matchAndRewrite(
     cir::SelectOp op, OpAdaptor adaptor,
     mlir::ConversionPatternRewriter &rewriter) const {
-  auto getConstantBool = [](mlir::Value value) -> cir::BoolAttr {
-    auto definingOp = value.getDefiningOp<cir::ConstantOp>();
-    if (!definingOp)
-      return {};
 
-    auto constValue = definingOp.getValueAttr<cir::BoolAttr>();
-    if (!constValue)
-      return {};
+  // Helper to extract boolean constant value
+  auto getConstantBool = [](mlir::Value value) -> std::optional<bool> {
+    auto constOp = value.getDefiningOp<mlir::LLVM::ConstantOp>();
+    if (!constOp)
+      return std::nullopt;
 
-    return constValue;
+    auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(constOp.getValue());
+    if (!intAttr)
+      return std::nullopt;
+
+    return !intAttr.getValue().isZero();
   };
 
+  mlir::Value condition = adaptor.getCondition();
+  mlir::Value trueValue = adaptor.getTrueValue();
+  mlir::Value falseValue = adaptor.getFalseValue();
+
   // Two special cases in the LLVMIR codegen of select op:
-  // - select %0, %1, false => and %0, %1
-  // - select %0, true, %1 => or %0, %1
+  // - select %cond, %val, false => and %cond, %val
+  // - select %cond, true, %val => or %cond, %val
   if (mlir::isa<cir::BoolType>(op.getTrueValue().getType())) {
-    cir::BoolAttr trueValue = getConstantBool(op.getTrueValue());
-    cir::BoolAttr falseValue = getConstantBool(op.getFalseValue());
-    if (falseValue && !falseValue.getValue()) {
-      // select %0, %1, false => and %0, %1
-      rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, adaptor.getCondition(),
-                                                     adaptor.getTrueValue());
+    // Optimization: select %cond, %val, false => and %cond, %val
+    std::optional<bool> falseConst = getConstantBool(falseValue);
+    if (falseConst && !*falseConst) {
+      rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, condition, trueValue);
       return mlir::success();
     }
-    if (trueValue && trueValue.getValue()) {
-      // select %0, true, %1 => or %0, %1
-      rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, adaptor.getCondition(),
-                                                    adaptor.getFalseValue());
+    // Optimization: select %cond, true, %val => or %cond, %val
+    std::optional<bool> trueConst = getConstantBool(trueValue);
+    if (trueConst && *trueConst) {
+      rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, condition, falseValue);
       return mlir::success();
     }
   }
 
-  mlir::Value llvmCondition = adaptor.getCondition();
-  rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
-      op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue());
-
+  // Default case: emit standard LLVM select
+  rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(op, condition, trueValue,
+                                                    falseValue);
   return mlir::success();
 }
 
@@ -3008,13 +3013,12 @@ static void buildCtorDtorList(
   mlir::LLVM::ReturnOp::create(builder, loc, result);
 }
 
-// The applyPartialConversion function traverses blocks in the dominance order,
-// so it does not lower and operations that are not reachachable from the
-// operations passed in as arguments. Since we do need to lower such code in
-// order to avoid verification errors occur, we cannot just pass the module op
-// to applyPartialConversion. We must build a set of unreachable ops and
-// explicitly add them, along with the module, to the vector we pass to
-// applyPartialConversion.
+// The conversion driver walks each region in forward dominance order, so it
+// does not visit operations in blocks that are unreachable from the entry
+// block. Those operations still have to be lowered to avoid verification
+// errors, so we collect them explicitly and pass them, along with the module,
+// to applyFullConversion, which fails if any operation is left illegal. We use
+// one-shot mode (allowPatternRollback = false) to avoid rollback bookkeeping.
 //
 // For instance, this CIR code:
 //
@@ -3135,7 +3139,10 @@ void ConvertCIRToLLVMPass::runOnOperation() {
   ops.push_back(module);
   collectUnreachable(module, ops);
 
-  if (failed(applyPartialConversion(ops, target, std::move(patterns))))
+  mlir::ConversionConfig config;
+  config.allowPatternRollback = false;
+
+  if (failed(applyFullConversion(ops, target, std::move(patterns), config)))
     signalPassFailure();
 
   // Emit the llvm.global_ctors array.
@@ -3750,11 +3757,31 @@ mlir::LogicalResult CIRToLLVMComplexRealOpLowering::matchAndRewrite(
     mlir::ConversionPatternRewriter &rewriter) const {
   mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
   mlir::Value operand = adaptor.getOperand();
+
+  // FIXME: Peephole fold; presumably a stopgap, remove once it is unneeded.
+  // If the operand comes from an already-lowered ComplexCreate, i.e.
+  // insertvalue(insertvalue(undef, real, 0), imag, 1), forward 'real'
+  // directly instead of emitting an extractvalue.
+  if (auto secondInsert = operand.getDefiningOp<mlir::LLVM::InsertValueOp>()) {
+    if (secondInsert.getPosition() == llvm::ArrayRef<int64_t>{1}) {
+      if (auto firstInsert = secondInsert.getContainer()
+                                 .getDefiningOp<mlir::LLVM::InsertValueOp>()) {
+        if (firstInsert.getPosition() == llvm::ArrayRef<int64_t>{0}) {
+          // This is the pattern we're looking for - return the real component
+          // directly
+          rewriter.replaceOp(op, firstInsert.getValue());
+          return mlir::success();
+        }
+      }
+    }
+  }
+
   if (mlir::isa<cir::ComplexType>(op.getOperand().getType())) {
     operand = mlir::LLVM::ExtractValueOp::create(
         rewriter, op.getLoc(), resultLLVMTy, operand,
         llvm::ArrayRef<std::int64_t>{0});
   }
+
   rewriter.replaceOp(op, operand);
   return mlir::success();
 }
@@ -3815,6 +3842,22 @@ mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite(
   mlir::Value operand = adaptor.getOperand();
   mlir::Location loc = op.getLoc();
 
+  // FIXME: Peephole fold; presumably a stopgap, remove once it is unneeded.
+  // If the operand comes from an already-lowered ComplexCreate, i.e.
+  // insertvalue(insertvalue(undef, real, 0), imag, 1), forward 'imag'
+  // directly instead of emitting an extractvalue.
+  if (auto secondInsert = operand.getDefiningOp<mlir::LLVM::InsertValueOp>()) {
+    if (secondInsert.getPosition() == llvm::ArrayRef<int64_t>{1}) {
+      if (secondInsert.getContainer()
+              .getDefiningOp<mlir::LLVM::InsertValueOp>()) {
+        // This is the pattern we're looking for - return the imag component
+        // directly
+        rewriter.replaceOp(op, secondInsert.getValue());
+        return mlir::success();
+      }
+    }
+  }
+
   if (mlir::isa<cir::ComplexType>(op.getOperand().getType())) {
     operand = mlir::LLVM::ExtractValueOp::create(
         rewriter, loc, resultLLVMTy, operand, llvm::ArrayRef<std::int64_t>{1});

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to