https://github.com/Andres-Salamanca created https://github.com/llvm/llvm-project/pull/176664
While working on the lowering of `BlockAddressOp`, I ran into a limitation of the current `CIRLoweringEmitter` infrastructure: some CIR ops require shared mutable state across multiple lowering patterns, but the generated patterns always use a fixed default constructor and cannot accept additional context. **Examples** * `BlockAddressOp` needs to coordinate with the lowering of its referenced `LabelOp`, requiring shared bookkeeping (e.g., `LLVMBlockAddressInfo`) to track lowered labels and resolve unresolved block addresses once the corresponding `BlockTagOp` is available. * This is not unique to block addresses: for example, `CIRToLLVMAllocaOpLowering` will also need access to shared state such as `stringGlobalsMap`, `argStringGlobalsMap`, and `argsVarMap`. **Possible solutions** * **Disable generated lowering**: mark the op with `hasLLVMLowering = false` and manually define and register the lowering patterns. This works, but feels misleading, since the operation does have an LLVM lowering; it just cannot be generated automatically. This ambiguity also suggests that `hasLLVMLowering` might be better named `genLLVMLowering`. * **Custom constructors (this PR)**: extend the TableGen emitter to allow defining a single custom constructor for the generated LLVM lowering pattern, similar to MLIR’s `OpBuilder`. When a custom constructor is provided, the default one is not generated and the pattern must be manually registered in the `RewritePatternSet`. 
>From 7643adcd6d1064e32d38844e5c6ef404e36c9b17 Mon Sep 17 00:00:00 2001 From: Andres Salamanca <[email protected]> Date: Sun, 21 Dec 2025 21:22:21 -0500 Subject: [PATCH] [CIR] Add custom constructor declaration to CIR lowering TableGen --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 26 +++++- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 15 ++++ .../CIR/Lowering/DirectToLLVM/LowerToLLVM.h | 42 +++++++++ clang/utils/TableGen/CIRLoweringEmitter.cpp | 87 ++++++++++++++++--- 4 files changed, 158 insertions(+), 12 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index f965b8a5b7cff..b0513ed10221f 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -77,6 +77,11 @@ class LLVMLoweringInfo { string llvmOp = ""; } +class loweringBuilders<dag p, code b = ""> { + dag dagParams = p; + code body = b; +} + class CIR_Op<string mnemonic, list<Trait> traits = []> : Op<CIR_Dialect, mnemonic, traits>, LLVMLoweringInfo { // Should we generate an LLVM lowering pattern for this op? @@ -86,6 +91,21 @@ class CIR_Op<string mnemonic, list<Trait> traits = []> : // Extra class declarations to be included in the generated LLVM lowering // pattern. code extraLLVMLoweringPatternDecl = ""; + // Optional custom constructor for the generated LLVM lowering pattern. + // + // By default, CIR generates a standard lowering pattern constructor that + // receives the TypeConverter, MLIRContext, LowerModule, and DataLayout. + // Some operations require additional state or external context to be passed + // to the lowering pattern. + // + // This field allows an operation to define a single custom constructor for + // its LLVM lowering pattern. When specified, the default constructor is not + // generated, and only this custom constructor is emitted. 
+ // NOTE: When providing a custom constructor, the corresponding lowering + // pattern is not automatically added to the RewritePatternSet. Users are + // responsible for manually registering the lowering pattern in the lowering + // pass. + loweringBuilders customLLVMLoweringConstructorDecl = ?; } //===----------------------------------------------------------------------===// @@ -1467,7 +1487,8 @@ def CIR_LabelOp : CIR_Op<"label", [AlwaysSpeculatable]> { let assemblyFormat = [{ $label attr-dict }]; let hasVerifier = 1; - let hasLLVMLowering = false; + let customLLVMLoweringConstructorDecl = + loweringBuilders<(ins "LLVMBlockAddressInfo &":$blockInfoAddr)>; } //===----------------------------------------------------------------------===// @@ -5735,6 +5756,9 @@ def CIR_BlockAddressOp : CIR_Op<"block_address", [Pure]> { let assemblyFormat = [{ $block_addr_info `:` qualified(type($addr)) attr-dict }]; + + let customLLVMLoweringConstructorDecl = + loweringBuilders<(ins "LLVMBlockAddressInfo &":$blockInfoAddr)>; } #endif // CLANG_CIR_DIALECT_IR_CIROPS_TD diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 9a2815031970a..4b8cecc960024 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -3174,7 +3174,16 @@ void ConvertCIRToLLVMPass::runOnOperation() { std::unique_ptr<cir::LowerModule> lowerModule = prepareLowerModule(module); prepareTypeConverter(converter, dl, lowerModule.get()); + /// Tracks the state required to lower CIR `LabelOp` and `BlockAddressOp`. + /// Maps labels to their corresponding `BlockTagOp` and keeps bookkeeping + /// of unresolved `BlockAddressOp`s until they are matched with the + /// corresponding `BlockTagOp` in `resolveBlockAddressOp`.
+ LLVMBlockAddressInfo blockInfoAddr; mlir::RewritePatternSet patterns(&getContext()); + patterns.add<CIRToLLVMBlockAddressOpLowering>( + converter, patterns.getContext(), lowerModule.get(), dl, blockInfoAddr); + patterns.add<CIRToLLVMLabelOpLowering>(converter, patterns.getContext(), + lowerModule.get(), dl, blockInfoAddr); patterns.add< #define GET_LLVM_LOWERING_PATTERNS_LIST @@ -4196,6 +4205,12 @@ mlir::LogicalResult CIRToLLVMVAArgOpLowering::matchAndRewrite( return mlir::success(); } +mlir::LogicalResult CIRToLLVMLabelOpLowering::matchAndRewrite( + cir::LabelOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + return mlir::failure(); +} + mlir::LogicalResult CIRToLLVMBlockAddressOpLowering::matchAndRewrite( cir::BlockAddressOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index d32f8603ee0be..7235794c3f9b4 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -36,6 +36,48 @@ void convertSideEffectForCall(mlir::Operation *callOp, bool isNothrow, mlir::LLVM::MemoryEffectsAttr &memoryEffect, bool &noUnwind, bool &willReturn); +struct LLVMBlockAddressInfo { + // Returns the next unused block tag index and advances the counter. + uint32_t getTagIndex() { return blockTagOpIndex++; } + + void mapBlockTag(cir::BlockAddrInfoAttr info, mlir::LLVM::BlockTagOp tagOp) { + [[maybe_unused]] auto result = blockInfoToTagOp.try_emplace(info, tagOp); + assert(result.second && + "attempting to map a BlockTag operation that is already mapped"); + } + + // Look up a BlockTagOp; returns a null op if it is not registered yet. + mlir::LLVM::BlockTagOp lookupBlockTag(cir::BlockAddrInfoAttr info) const { + return blockInfoToTagOp.lookup(info); + } + + // Record an unresolved BlockAddressOp that needs patching later.
+ void addUnresolvedBlockAddress(mlir::LLVM::BlockAddressOp op, + cir::BlockAddrInfoAttr info) { + unresolvedBlockAddressOp.try_emplace(op, info); + } + + void clearUnresolvedMap() { unresolvedBlockAddressOp.clear(); } + + llvm::DenseMap<mlir::LLVM::BlockAddressOp, cir::BlockAddrInfoAttr> & + getUnresolvedBlockAddress() { + return unresolvedBlockAddressOp; + } + +private: + // Maps a (function name, label name) pair to the corresponding BlockTagOp. + // Used to resolve CIR LabelOps into their LLVM BlockTagOp. + llvm::DenseMap<cir::BlockAddrInfoAttr, mlir::LLVM::BlockTagOp> + blockInfoToTagOp; + // Tracks BlockAddressOps that could not yet be fully resolved because + // their BlockTagOp was not available at the time of lowering. The map + // stores the unresolved BlockAddressOp along with its (function name, label + // name) pair so it can be patched later. + llvm::DenseMap<mlir::LLVM::BlockAddressOp, cir::BlockAddrInfoAttr> + unresolvedBlockAddressOp; + uint32_t blockTagOpIndex = 0; +}; + #define GET_LLVM_LOWERING_PATTERNS #include "clang/CIR/Dialect/IR/CIRLowering.inc" #undef GET_LLVM_LOWERING_PATTERNS diff --git a/clang/utils/TableGen/CIRLoweringEmitter.cpp b/clang/utils/TableGen/CIRLoweringEmitter.cpp index c81b8941f9a39..4fa3502891ea3 100644 --- a/clang/utils/TableGen/CIRLoweringEmitter.cpp +++ b/clang/utils/TableGen/CIRLoweringEmitter.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "TableGenBackends.h" +#include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" #include <string> #include <utility> @@ -23,6 +24,16 @@ namespace { std::vector<std::string> LLVMLoweringPatterns; std::vector<std::string> LLVMLoweringPatternsList; +struct CustomLoweringCtor { + struct Param { + std::string Type; + std::string Name; + }; + + std::vector<Param> Params; + std::string Body; +}; + // Adapted from mlir/lib/TableGen/Operator.cpp // Returns the C++ class name of the operation, which is the name of
the // operation with the dialect prefix removed and the first underscore removed. @@ -51,10 +62,40 @@ std::string GetOpLLVMLoweringPatternName(llvm::StringRef OpName) { Name += "Lowering"; return Name; } +std::optional<CustomLoweringCtor> parseCustomLoweringCtor(const Record *R) { + if (!R) + return std::nullopt; + + CustomLoweringCtor Ctor; + const DagInit *Args = R->getValueAsDag("dagParams"); + Ctor.Body = R->getValueAsString("body"); + + for (const auto &[Arg, Name] : Args->getArgAndNames()) { + Ctor.Params.push_back( + {Arg->getAsUnquotedString(), Name->getAsUnquotedString()}); + } + + return Ctor; +} +void emitCustomParamList(raw_ostream &Code, + ArrayRef<CustomLoweringCtor::Param> Params) { + for (const auto &Param : Params) { + Code << ", "; + Code << Param.Type << " " << Param.Name; + } +} + +void emitCustomInitList(raw_ostream &Code, + ArrayRef<CustomLoweringCtor::Param> Params) { + for (const auto &P : Params) + Code << ", " << P.Name << "(" << P.Name << ")"; +} void GenerateLLVMLoweringPattern(llvm::StringRef OpName, llvm::StringRef PatternName, bool IsRecursive, - llvm::StringRef ExtraDecl) { + llvm::StringRef ExtraDecl, + const Record *CustomCtorRec) { + auto CustomCtor = parseCustomLoweringCtor(CustomCtorRec); std::string CodeBuffer; llvm::raw_string_ostream Code(CodeBuffer); @@ -62,29 +103,46 @@ void GenerateLLVMLoweringPattern(llvm::StringRef OpName, << " : public mlir::OpConversionPattern<cir::" << OpName << "> {\n"; Code << " [[maybe_unused]] cir::LowerModule *lowerMod;\n"; Code << " [[maybe_unused]] mlir::DataLayout const &dataLayout;\n"; + + if (CustomCtor) { + for (const auto &P : CustomCtor->Params) + Code << " " << P.Type << " " << P.Name << ";\n"; + } + Code << "\n"; Code << "public:\n"; Code << " using mlir::OpConversionPattern<cir::" << OpName << ">::OpConversionPattern;\n"; + // Constructor Code << " " << PatternName << "(mlir::TypeConverter const " "&typeConverter, mlir::MLIRContext *context, " "cir::LowerModule *lowerMod, mlir::DataLayout const "
- "&dataLayout)\n"; + "&dataLayout"; + + if (CustomCtor) + emitCustomParamList(Code, CustomCtor->Params); + + Code << ")\n"; + Code << " : OpConversionPattern<cir::" << OpName << ">(typeConverter, context), lowerMod(lowerMod), " "dataLayout(dataLayout)"; - if (IsRecursive) { - Code << " {\n"; + + if (CustomCtor) + emitCustomInitList(Code, CustomCtor->Params); + + Code << " {\n"; + + if (IsRecursive) Code << " setHasBoundedRewriteRecursion();\n"; - Code << " }\n"; - } else { - Code << " {}\n"; - } - Code << "\n"; + if (CustomCtor && !CustomCtor->Body.empty()) + Code << CustomCtor->Body << "\n"; + + Code << " }\n\n"; Code << " mlir::LogicalResult matchAndRewrite(cir::" << OpName << " op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) " @@ -106,11 +164,18 @@ void Generate(const Record *OpRecord) { if (OpRecord->getValueAsBit("hasLLVMLowering")) { std::string PatternName = GetOpLLVMLoweringPatternName(OpName); bool IsRecursive = OpRecord->getValueAsBit("isLLVMLoweringRecursive"); + const Record *CustomCtor = + OpRecord->getValueAsOptionalDef("customLLVMLoweringConstructorDecl"); llvm::StringRef ExtraDecl = OpRecord->getValueAsString("extraLLVMLoweringPatternDecl"); - GenerateLLVMLoweringPattern(OpName, PatternName, IsRecursive, ExtraDecl); - LLVMLoweringPatternsList.push_back(std::move(PatternName)); + GenerateLLVMLoweringPattern(OpName, PatternName, IsRecursive, ExtraDecl, + CustomCtor); + // Only automatically register patterns that use the default constructor. + // Patterns with a custom constructor must be manually registered by the + // lowering pass. + if (!CustomCtor) + LLVMLoweringPatternsList.push_back(std::move(PatternName)); } } } // namespace _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
