https://github.com/andykaylor updated https://github.com/llvm/llvm-project/pull/129293
>From bb41af68d0d0f66c5610c69d6deb8a615d644fe5 Mon Sep 17 00:00:00 2001 From: Andy Kaylor <akay...@nvidia.com> Date: Fri, 28 Feb 2025 10:54:09 -0800 Subject: [PATCH 1/3] [CIR] Replace CIRAttrVisitor with TypeSwitch We previously discussed having an mlir-tblgen utility to complete the CIRAttrVisitor implementation with all supported attribute types, but when I proposed an implementation to do this, a reviewer suggested using TypeSwitch instead, and I have done that in the incubator. See https://github.com/llvm/llvm-project/pull/126332 This change brings the TypeSwitch implementation into the upstream repo to replace the visitor class. --- .../clang/CIR/Dialect/IR/CIRAttrVisitor.h | 52 ------------------- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 31 +++++++---- 2 files changed, 22 insertions(+), 61 deletions(-) delete mode 100644 clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h b/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h deleted file mode 100644 index bbba89cb7e3fd..0000000000000 --- a/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h +++ /dev/null @@ -1,52 +0,0 @@ -//===- CIRAttrVisitor.h - Visitor for CIR attributes ------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the CirAttrVisitor interface.
-// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H -#define LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H - -#include "clang/CIR/Dialect/IR/CIRAttrs.h" - -namespace cir { - -template <typename ImplClass, typename RetTy> class CirAttrVisitor { -public: - // FIXME: Create a TableGen list to automatically handle new attributes - RetTy visit(mlir::Attribute attr) { - if (const auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr)) - return getImpl().visitCirIntAttr(intAttr); - if (const auto fltAttr = mlir::dyn_cast<cir::FPAttr>(attr)) - return getImpl().visitCirFPAttr(fltAttr); - if (const auto ptrAttr = mlir::dyn_cast<cir::ConstPtrAttr>(attr)) - return getImpl().visitCirConstPtrAttr(ptrAttr); - llvm_unreachable("unhandled attribute type"); - } - - // If the implementation chooses not to implement a certain visit - // method, fall back to the parent. - RetTy visitCirIntAttr(cir::IntAttr attr) { - return getImpl().visitCirAttr(attr); - } - RetTy visitCirFPAttr(cir::FPAttr attr) { - return getImpl().visitCirAttr(attr); - } - RetTy visitCirConstPtrAttr(cir::ConstPtrAttr attr) { - return getImpl().visitCirAttr(attr); - } - - RetTy visitCirAttr(mlir::Attribute attr) { return RetTy(); } - - ImplClass &getImpl() { return *static_cast<ImplClass *>(this); } -}; - -} // namespace cir - -#endif // LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index ba7fab2865116..7bf4b5fd27b61 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -24,10 +24,10 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Transforms/DialectConversion.h" -#include "clang/CIR/Dialect/IR/CIRAttrVisitor.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" #include 
"clang/CIR/MissingFeatures.h" #include "clang/CIR/Passes.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/Module.h" #include "llvm/Support/TimeProfiler.h" @@ -37,7 +37,7 @@ using namespace llvm; namespace cir { namespace direct { -class CIRAttrToValue : public CirAttrVisitor<CIRAttrToValue, mlir::Value> { +class CIRAttrToValue { public: CIRAttrToValue(mlir::Operation *parentOp, mlir::ConversionPatternRewriter &rewriter, @@ -46,19 +46,26 @@ class CIRAttrToValue : public CirAttrVisitor<CIRAttrToValue, mlir::Value> { mlir::Value lowerCirAttrAsValue(mlir::Attribute attr) { return visit(attr); } - mlir::Value visitCirIntAttr(cir::IntAttr intAttr) { + mlir::Value visit(mlir::Attribute attr) { + return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr) + .Case<cir::IntAttr, cir::FPAttr, cir::ConstPtrAttr>( + [&](auto attrT) { return visitCirAttr(attrT); }) + .Default([&](auto attrT) { return mlir::Value(); }); + } + + mlir::Value visitCirAttr(cir::IntAttr intAttr) { mlir::Location loc = parentOp->getLoc(); return rewriter.create<mlir::LLVM::ConstantOp>( loc, converter->convertType(intAttr.getType()), intAttr.getValue()); } - mlir::Value visitCirFPAttr(cir::FPAttr fltAttr) { + mlir::Value visitCirAttr(cir::FPAttr fltAttr) { mlir::Location loc = parentOp->getLoc(); return rewriter.create<mlir::LLVM::ConstantOp>( loc, converter->convertType(fltAttr.getType()), fltAttr.getValue()); } - mlir::Value visitCirConstPtrAttr(cir::ConstPtrAttr ptrAttr) { + mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr) { mlir::Location loc = parentOp->getLoc(); if (ptrAttr.isNullValue()) { return rewriter.create<mlir::LLVM::ZeroOp>( @@ -81,8 +88,7 @@ class CIRAttrToValue : public CirAttrVisitor<CIRAttrToValue, mlir::Value> { // This class handles rewriting initializer attributes for types that do not // require region initialization. 
-class GlobalInitAttrRewriter - : public CirAttrVisitor<GlobalInitAttrRewriter, mlir::Attribute> { +class GlobalInitAttrRewriter { public: GlobalInitAttrRewriter(mlir::Type type, mlir::ConversionPatternRewriter &rewriter) @@ -90,10 +96,17 @@ class GlobalInitAttrRewriter mlir::Attribute rewriteInitAttr(mlir::Attribute attr) { return visit(attr); } - mlir::Attribute visitCirIntAttr(cir::IntAttr attr) { + mlir::Attribute visit(mlir::Attribute attr) { + return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr) + .Case<cir::IntAttr, cir::FPAttr>( + [&](auto attrT) { return visitCirAttr(attrT); }) + .Default([&](auto attrT) { return mlir::Attribute(); }); + } + + mlir::Attribute visitCirAttr(cir::IntAttr attr) { return rewriter.getIntegerAttr(llvmType, attr.getValue()); } - mlir::Attribute visitCirFPAttr(cir::FPAttr attr) { + mlir::Attribute visitCirAttr(cir::FPAttr attr) { return rewriter.getFloatAttr(llvmType, attr.getValue()); } >From cba15182d96bbdc28009501256fa93951089e70e Mon Sep 17 00:00:00 2001 From: Andy Kaylor <akay...@nvidia.com> Date: Fri, 28 Feb 2025 11:09:40 -0800 Subject: [PATCH 2/3] Re-align lowering code with incubator implementation --- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 109 ++++++++++-------- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.h | 2 - 2 files changed, 59 insertions(+), 52 deletions(-) diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 7bf4b5fd27b61..5d083efcdda6f 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -44,8 +44,6 @@ class CIRAttrToValue { const mlir::TypeConverter *converter) : parentOp(parentOp), rewriter(rewriter), converter(converter) {} - mlir::Value lowerCirAttrAsValue(mlir::Attribute attr) { return visit(attr); } - mlir::Value visit(mlir::Attribute attr) { return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr) .Case<cir::IntAttr, cir::FPAttr, 
cir::ConstPtrAttr>( @@ -53,32 +51,9 @@ class CIRAttrToValue { .Default([&](auto attrT) { return mlir::Value(); }); } - mlir::Value visitCirAttr(cir::IntAttr intAttr) { - mlir::Location loc = parentOp->getLoc(); - return rewriter.create<mlir::LLVM::ConstantOp>( - loc, converter->convertType(intAttr.getType()), intAttr.getValue()); - } - - mlir::Value visitCirAttr(cir::FPAttr fltAttr) { - mlir::Location loc = parentOp->getLoc(); - return rewriter.create<mlir::LLVM::ConstantOp>( - loc, converter->convertType(fltAttr.getType()), fltAttr.getValue()); - } - - mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr) { - mlir::Location loc = parentOp->getLoc(); - if (ptrAttr.isNullValue()) { - return rewriter.create<mlir::LLVM::ZeroOp>( - loc, converter->convertType(ptrAttr.getType())); - } - mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>()); - mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>( - loc, - rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())), - ptrAttr.getValue().getInt()); - return rewriter.create<mlir::LLVM::IntToPtrOp>( - loc, converter->convertType(ptrAttr.getType()), ptrVal); - } + mlir::Value visitCirAttr(cir::IntAttr intAttr); + mlir::Value visitCirAttr(cir::FPAttr fltAttr); + mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr); private: mlir::Operation *parentOp; @@ -86,6 +61,35 @@ class CIRAttrToValue { const mlir::TypeConverter *converter; }; +/// IntAttr visitor. +mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) { + mlir::Location loc = parentOp->getLoc(); + return rewriter.create<mlir::LLVM::ConstantOp>( + loc, converter->convertType(intAttr.getType()), intAttr.getValue()); +} + +/// ConstPtrAttr visitor. 
+mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) { + mlir::Location loc = parentOp->getLoc(); + if (ptrAttr.isNullValue()) { + return rewriter.create<mlir::LLVM::ZeroOp>( + loc, converter->convertType(ptrAttr.getType())); + } + mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>()); + mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>( + loc, rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())), + ptrAttr.getValue().getInt()); + return rewriter.create<mlir::LLVM::IntToPtrOp>( + loc, converter->convertType(ptrAttr.getType()), ptrVal); +} + +/// FPAttr visitor. +mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) { + mlir::Location loc = parentOp->getLoc(); + return rewriter.create<mlir::LLVM::ConstantOp>( + loc, converter->convertType(fltAttr.getType()), fltAttr.getValue()); +} + // This class handles rewriting initializer attributes for types that do not // require region initialization. class GlobalInitAttrRewriter { @@ -94,8 +98,6 @@ class GlobalInitAttrRewriter { mlir::ConversionPatternRewriter &rewriter) : llvmType(type), rewriter(rewriter) {} - mlir::Attribute rewriteInitAttr(mlir::Attribute attr) { return visit(attr); } - mlir::Attribute visit(mlir::Attribute attr) { return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr) .Case<cir::IntAttr, cir::FPAttr>( @@ -137,12 +139,6 @@ struct ConvertCIRToLLVMPass StringRef getArgument() const override { return "cir-flat-to-llvm"; } }; -bool CIRToLLVMGlobalOpLowering::attrRequiresRegionInitialization( - mlir::Attribute attr) const { - // There will be more cases added later. - return isa<cir::ConstPtrAttr>(attr); -} - /// Replace CIR global with a region initialized LLVM global and update /// insertion point to the end of the initializer block. void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp( @@ -189,8 +185,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal( // to the appropriate value. 
const mlir::Location loc = op.getLoc(); setupRegionInitializedLLVMGlobalOp(op, rewriter); - CIRAttrToValue attrVisitor(op, rewriter, typeConverter); - mlir::Value value = attrVisitor.lowerCirAttrAsValue(init); + CIRAttrToValue valueConverter(op, rewriter, typeConverter); + mlir::Value value = valueConverter.visit(init); rewriter.create<mlir::LLVM::ReturnOp>(loc, value); return mlir::success(); } @@ -201,12 +197,6 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( std::optional<mlir::Attribute> init = op.getInitialValue(); - // If we have an initializer and it requires region initialization, handle - // that separately - if (init.has_value() && attrRequiresRegionInitialization(init.value())) { - return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter); - } - // Fetch required values to create LLVM op. const mlir::Type cirSymType = op.getSymType(); @@ -231,12 +221,31 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( SmallVector<mlir::NamedAttribute> attributes; if (init.has_value()) { - GlobalInitAttrRewriter initRewriter(llvmType, rewriter); - init = initRewriter.rewriteInitAttr(init.value()); - // If initRewriter returned a null attribute, init will have a value but - // the value will be null. If that happens, initRewriter didn't handle the - // attribute type. It probably needs to be added to GlobalInitAttrRewriter. - if (!init.value()) { + if (mlir::isa<cir::FPAttr, cir::IntAttr>(init.value())) { + // If a directly equivalent attribute is available, use it. 
+ init = + llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init.value()) + .Case<cir::FPAttr>([&](cir::FPAttr attr) { + return rewriter.getFloatAttr(llvmType, attr.getValue()); + }) + .Case<cir::IntAttr>([&](cir::IntAttr attr) { + return rewriter.getIntegerAttr(llvmType, attr.getValue()); + }) + .Default([&](mlir::Attribute attr) { return mlir::Attribute(); }); + // If initRewriter returned a null attribute, init will have a value but + // the value will be null. + if (!init.value()) { + op.emitError() << "unsupported initializer '" << init.value() << "'"; + return mlir::failure(); + } + } else if (mlir::isa<cir::ConstPtrAttr>(init.value())) { + // TODO(cir): once LLVM's dialect has proper equivalent attributes this + // should be updated. For now, we use a custom op to initialize globals + // to the appropriate value. + return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter); + } else { + // We will only get here if new initializer types are added and this + // code is not updated to handle them. 
op.emitError() << "unsupported initializer '" << init.value() << "'"; return mlir::failure(); } diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index b3366c1fb9337..d1109bb7e1c08 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -36,8 +36,6 @@ class CIRToLLVMGlobalOpLowering mlir::ConversionPatternRewriter &rewriter) const override; private: - bool attrRequiresRegionInitialization(mlir::Attribute attr) const; - mlir::LogicalResult matchAndRewriteRegionInitializedGlobal( cir::GlobalOp op, mlir::Attribute init, mlir::ConversionPatternRewriter &rewriter) const; >From c76fb372cdae335a182f60f8f9c6dea207e8bd1d Mon Sep 17 00:00:00 2001 From: Andy Kaylor <akay...@nvidia.com> Date: Fri, 28 Feb 2025 13:34:57 -0800 Subject: [PATCH 3/3] Restore use of GlobalInitAttrRewriter --- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 5d083efcdda6f..6f7cae8fa7fa3 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -222,18 +222,12 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( if (init.has_value()) { if (mlir::isa<cir::FPAttr, cir::IntAttr>(init.value())) { - // If a directly equivalent attribute is available, use it. 
- init = - llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init.value()) - .Case<cir::FPAttr>([&](cir::FPAttr attr) { - return rewriter.getFloatAttr(llvmType, attr.getValue()); - }) - .Case<cir::IntAttr>([&](cir::IntAttr attr) { - return rewriter.getIntegerAttr(llvmType, attr.getValue()); - }) - .Default([&](mlir::Attribute attr) { return mlir::Attribute(); }); + GlobalInitAttrRewriter initRewriter(llvmType, rewriter); + init = initRewriter.visit(init.value()); // If initRewriter returned a null attribute, init will have a value but - // the value will be null. + // the value will be null. If that happens, initRewriter didn't handle the + // attribute type. It probably needs to be added to + // GlobalInitAttrRewriter. if (!init.value()) { op.emitError() << "unsupported initializer '" << init.value() << "'"; return mlir::failure(); _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits