llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Andy Kaylor (andykaylor) <details> <summary>Changes</summary> The previously upstreamed lowering from ClangIR to LLVM IR diverged from the incubator implementation, but when the incubator was updated to incorporate these changes, some issues arose that require the upstream implementation to be modified to re-align with the incubator. First, in the earlier upstream implementation, a CIRAttrVisitor class was introduced with the intention that an mlir-tblgen-based extension would be created to automatically add all CIR attributes to the visitor. When I proposed this extension to mlir-tblgen, a reviewer suggested that what I wanted could be better accomplished with llvm::TypeSwitch (see https://github.com/llvm/llvm-project/pull/126332). This was done in the incubator, and here I am bringing that implementation upstream. The other issue was that the global op initialization in the incubator had more cases than I had accounted for in my previous upstream refactoring. I did still refactor the incubator code, but not in quite the same way as the upstream code. This change re-aligns the two. --- Full diff: https://github.com/llvm/llvm-project/pull/129293.diff 3 Files Affected: - (removed) clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h (-52) - (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+75-53) - (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h (-2) ``````````diff diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h b/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h deleted file mode 100644 index bbba89cb7e3fd..0000000000000 --- a/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h +++ /dev/null @@ -1,52 +0,0 @@ -//===- CIRAttrVisitor.h - Visitor for CIR attributes ------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the CirAttrVisitor interface. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H -#define LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H - -#include "clang/CIR/Dialect/IR/CIRAttrs.h" - -namespace cir { - -template <typename ImplClass, typename RetTy> class CirAttrVisitor { -public: - // FIXME: Create a TableGen list to automatically handle new attributes - RetTy visit(mlir::Attribute attr) { - if (const auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr)) - return getImpl().visitCirIntAttr(intAttr); - if (const auto fltAttr = mlir::dyn_cast<cir::FPAttr>(attr)) - return getImpl().visitCirFPAttr(fltAttr); - if (const auto ptrAttr = mlir::dyn_cast<cir::ConstPtrAttr>(attr)) - return getImpl().visitCirConstPtrAttr(ptrAttr); - llvm_unreachable("unhandled attribute type"); - } - - // If the implementation chooses not to implement a certain visit - // method, fall back to the parent. 
- RetTy visitCirIntAttr(cir::IntAttr attr) { - return getImpl().visitCirAttr(attr); - } - RetTy visitCirFPAttr(cir::FPAttr attr) { - return getImpl().visitCirAttr(attr); - } - RetTy visitCirConstPtrAttr(cir::ConstPtrAttr attr) { - return getImpl().visitCirAttr(attr); - } - - RetTy visitCirAttr(mlir::Attribute attr) { return RetTy(); } - - ImplClass &getImpl() { return *static_cast<ImplClass *>(this); } -}; - -} // namespace cir - -#endif // LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index ba7fab2865116..5d083efcdda6f 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -24,10 +24,10 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Transforms/DialectConversion.h" -#include "clang/CIR/Dialect/IR/CIRAttrVisitor.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" #include "clang/CIR/MissingFeatures.h" #include "clang/CIR/Passes.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/Module.h" #include "llvm/Support/TimeProfiler.h" @@ -37,41 +37,23 @@ using namespace llvm; namespace cir { namespace direct { -class CIRAttrToValue : public CirAttrVisitor<CIRAttrToValue, mlir::Value> { +class CIRAttrToValue { public: CIRAttrToValue(mlir::Operation *parentOp, mlir::ConversionPatternRewriter &rewriter, const mlir::TypeConverter *converter) : parentOp(parentOp), rewriter(rewriter), converter(converter) {} - mlir::Value lowerCirAttrAsValue(mlir::Attribute attr) { return visit(attr); } - - mlir::Value visitCirIntAttr(cir::IntAttr intAttr) { - mlir::Location loc = parentOp->getLoc(); - return rewriter.create<mlir::LLVM::ConstantOp>( - loc, converter->convertType(intAttr.getType()), intAttr.getValue()); - } - - mlir::Value visitCirFPAttr(cir::FPAttr fltAttr) { - mlir::Location loc = parentOp->getLoc(); - return 
rewriter.create<mlir::LLVM::ConstantOp>( - loc, converter->convertType(fltAttr.getType()), fltAttr.getValue()); + mlir::Value visit(mlir::Attribute attr) { + return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr) + .Case<cir::IntAttr, cir::FPAttr, cir::ConstPtrAttr>( + [&](auto attrT) { return visitCirAttr(attrT); }) + .Default([&](auto attrT) { return mlir::Value(); }); } - mlir::Value visitCirConstPtrAttr(cir::ConstPtrAttr ptrAttr) { - mlir::Location loc = parentOp->getLoc(); - if (ptrAttr.isNullValue()) { - return rewriter.create<mlir::LLVM::ZeroOp>( - loc, converter->convertType(ptrAttr.getType())); - } - mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>()); - mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>( - loc, - rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())), - ptrAttr.getValue().getInt()); - return rewriter.create<mlir::LLVM::IntToPtrOp>( - loc, converter->convertType(ptrAttr.getType()), ptrVal); - } + mlir::Value visitCirAttr(cir::IntAttr intAttr); + mlir::Value visitCirAttr(cir::FPAttr fltAttr); + mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr); private: mlir::Operation *parentOp; @@ -79,21 +61,54 @@ class CIRAttrToValue : public CirAttrVisitor<CIRAttrToValue, mlir::Value> { const mlir::TypeConverter *converter; }; +/// IntAttr visitor. +mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) { + mlir::Location loc = parentOp->getLoc(); + return rewriter.create<mlir::LLVM::ConstantOp>( + loc, converter->convertType(intAttr.getType()), intAttr.getValue()); +} + +/// ConstPtrAttr visitor. 
+mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) { + mlir::Location loc = parentOp->getLoc(); + if (ptrAttr.isNullValue()) { + return rewriter.create<mlir::LLVM::ZeroOp>( + loc, converter->convertType(ptrAttr.getType())); + } + mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>()); + mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>( + loc, rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())), + ptrAttr.getValue().getInt()); + return rewriter.create<mlir::LLVM::IntToPtrOp>( + loc, converter->convertType(ptrAttr.getType()), ptrVal); +} + +/// FPAttr visitor. +mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) { + mlir::Location loc = parentOp->getLoc(); + return rewriter.create<mlir::LLVM::ConstantOp>( + loc, converter->convertType(fltAttr.getType()), fltAttr.getValue()); +} + // This class handles rewriting initializer attributes for types that do not // require region initialization. -class GlobalInitAttrRewriter - : public CirAttrVisitor<GlobalInitAttrRewriter, mlir::Attribute> { +class GlobalInitAttrRewriter { public: GlobalInitAttrRewriter(mlir::Type type, mlir::ConversionPatternRewriter &rewriter) : llvmType(type), rewriter(rewriter) {} - mlir::Attribute rewriteInitAttr(mlir::Attribute attr) { return visit(attr); } + mlir::Attribute visit(mlir::Attribute attr) { + return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr) + .Case<cir::IntAttr, cir::FPAttr>( + [&](auto attrT) { return visitCirAttr(attrT); }) + .Default([&](auto attrT) { return mlir::Attribute(); }); + } - mlir::Attribute visitCirIntAttr(cir::IntAttr attr) { + mlir::Attribute visitCirAttr(cir::IntAttr attr) { return rewriter.getIntegerAttr(llvmType, attr.getValue()); } - mlir::Attribute visitCirFPAttr(cir::FPAttr attr) { + mlir::Attribute visitCirAttr(cir::FPAttr attr) { return rewriter.getFloatAttr(llvmType, attr.getValue()); } @@ -124,12 +139,6 @@ struct ConvertCIRToLLVMPass StringRef getArgument() const override { 
return "cir-flat-to-llvm"; } }; -bool CIRToLLVMGlobalOpLowering::attrRequiresRegionInitialization( - mlir::Attribute attr) const { - // There will be more cases added later. - return isa<cir::ConstPtrAttr>(attr); -} - /// Replace CIR global with a region initialized LLVM global and update /// insertion point to the end of the initializer block. void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp( @@ -176,8 +185,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal( // to the appropriate value. const mlir::Location loc = op.getLoc(); setupRegionInitializedLLVMGlobalOp(op, rewriter); - CIRAttrToValue attrVisitor(op, rewriter, typeConverter); - mlir::Value value = attrVisitor.lowerCirAttrAsValue(init); + CIRAttrToValue valueConverter(op, rewriter, typeConverter); + mlir::Value value = valueConverter.visit(init); rewriter.create<mlir::LLVM::ReturnOp>(loc, value); return mlir::success(); } @@ -188,12 +197,6 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( std::optional<mlir::Attribute> init = op.getInitialValue(); - // If we have an initializer and it requires region initialization, handle - // that separately - if (init.has_value() && attrRequiresRegionInitialization(init.value())) { - return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter); - } - // Fetch required values to create LLVM op. const mlir::Type cirSymType = op.getSymType(); @@ -218,12 +221,31 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( SmallVector<mlir::NamedAttribute> attributes; if (init.has_value()) { - GlobalInitAttrRewriter initRewriter(llvmType, rewriter); - init = initRewriter.rewriteInitAttr(init.value()); - // If initRewriter returned a null attribute, init will have a value but - // the value will be null. If that happens, initRewriter didn't handle the - // attribute type. It probably needs to be added to GlobalInitAttrRewriter. 
- if (!init.value()) { + if (mlir::isa<cir::FPAttr, cir::IntAttr>(init.value())) { + // If a directly equivalent attribute is available, use it. + init = + llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init.value()) + .Case<cir::FPAttr>([&](cir::FPAttr attr) { + return rewriter.getFloatAttr(llvmType, attr.getValue()); + }) + .Case<cir::IntAttr>([&](cir::IntAttr attr) { + return rewriter.getIntegerAttr(llvmType, attr.getValue()); + }) + .Default([&](mlir::Attribute attr) { return mlir::Attribute(); }); + // If initRewriter returned a null attribute, init will have a value but + // the value will be null. + if (!init.value()) { + op.emitError() << "unsupported initializer '" << init.value() << "'"; + return mlir::failure(); + } + } else if (mlir::isa<cir::ConstPtrAttr>(init.value())) { + // TODO(cir): once LLVM's dialect has proper equivalent attributes this + // should be updated. For now, we use a custom op to initialize globals + // to the appropriate value. + return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter); + } else { + // We will only get here if new initializer types are added and this + // code is not updated to handle them. 
op.emitError() << "unsupported initializer '" << init.value() << "'"; return mlir::failure(); } diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index b3366c1fb9337..d1109bb7e1c08 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -36,8 +36,6 @@ class CIRToLLVMGlobalOpLowering mlir::ConversionPatternRewriter &rewriter) const override; private: - bool attrRequiresRegionInitialization(mlir::Attribute attr) const; - mlir::LogicalResult matchAndRewriteRegionInitializedGlobal( cir::GlobalOp op, mlir::Attribute init, mlir::ConversionPatternRewriter &rewriter) const; `````````` </details> https://github.com/llvm/llvm-project/pull/129293 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits