Author: Christian Sigg Date: 2020-12-04T14:27:16+01:00 New Revision: dcec2ca5bd3d82ebbe57d47fc2bdd742d35e8947
URL: https://github.com/llvm/llvm-project/commit/dcec2ca5bd3d82ebbe57d47fc2bdd742d35e8947 DIFF: https://github.com/llvm/llvm-project/commit/dcec2ca5bd3d82ebbe57d47fc2bdd742d35e8947.diff LOG: Remove typeConverter from ConvertToLLVMPattern and use the existing one in ConversionPattern. ftynse Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D92564 Added: Modified: mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp Removed: ################################################################################ diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h index 7b8bcdff4deb..bf41f29749de 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -71,7 +71,7 @@ class LLVMTypeConverter : public TypeConverter { /// Convert a function type. The arguments and results are converted one by /// one and results are packed into a wrapped LLVM IR structure type. `result` /// is populated with argument mapping. 
- LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic, + LLVM::LLVMType convertFunctionSignature(FunctionType funcTy, bool isVariadic, SignatureConversion &result); /// Convert a non-empty list of types to be returned from a function into a @@ -485,6 +485,8 @@ class ConvertToLLVMPattern : public ConversionPattern { /// Returns the LLVM dialect. LLVM::LLVMDialect &getDialect() const; + LLVMTypeConverter *getTypeConverter() const; + /// Gets the MLIR type wrapping the LLVM integer type whose bit width is /// defined by the used type converter. LLVM::LLVMType getIndexType() const; @@ -556,10 +558,6 @@ class ConvertToLLVMPattern : public ConversionPattern { Value allocatedPtr, Value alignedPtr, ArrayRef<Value> sizes, ArrayRef<Value> strides, ConversionPatternRewriter &rewriter) const; - -protected: - /// Reference to the type converter, with potential extensions. - LLVMTypeConverter &typeConverter; }; /// Utility class for operation conversions targeting the LLVM dialect that @@ -644,7 +642,7 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> { matchAndRewrite(SourceOp op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(), - operands, this->typeConverter, + operands, *this->getTypeConverter(), rewriter); } }; @@ -666,9 +664,9 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> { static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>, SourceOp>::value, "expected same operands and result type"); - return LLVM::detail::vectorOneToOneRewrite(op, TargetOp::getOperationName(), - operands, this->typeConverter, - rewriter); + return LLVM::detail::vectorOneToOneRewrite( + op, TargetOp::getOperationName(), operands, *this->getTypeConverter(), + rewriter); } }; diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp 
b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp index 3950562539f6..fe06e12c8f21 100644 --- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp +++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp @@ -86,7 +86,7 @@ struct MaskRndScaleOpPS512Conversion : public ConvertToLLVMPattern { return failure(); return matchAndRewriteOneToOne<MaskRndScaleOp, LLVM::x86_avx512_mask_rndscale_ps_512>( - *this, this->typeConverter, op, operands, rewriter); + *this, *getTypeConverter(), op, operands, rewriter); } }; @@ -103,7 +103,7 @@ struct MaskRndScaleOpPD512Conversion : public ConvertToLLVMPattern { return failure(); return matchAndRewriteOneToOne<MaskRndScaleOp, LLVM::x86_avx512_mask_rndscale_pd_512>( - *this, this->typeConverter, op, operands, rewriter); + *this, *getTypeConverter(), op, operands, rewriter); } }; @@ -120,7 +120,7 @@ struct ScaleFOpPS512Conversion : public ConvertToLLVMPattern { return failure(); return matchAndRewriteOneToOne<MaskScaleFOp, LLVM::x86_avx512_mask_scalef_ps_512>( - *this, this->typeConverter, op, operands, rewriter); + *this, *getTypeConverter(), op, operands, rewriter); } }; @@ -137,7 +137,7 @@ struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern { return failure(); return matchAndRewriteOneToOne<MaskScaleFOp, LLVM::x86_avx512_mask_scalef_pd_512>( - *this, this->typeConverter, op, operands, rewriter); + *this, *getTypeConverter(), op, operands, rewriter); } }; } // namespace diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp index ad84216d1e3b..810511194f68 100644 --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -72,7 +72,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> { : ConvertOpToLLVMPattern<OpTy>(typeConverter) {} protected: - MLIRContext *context = 
&this->typeConverter.getContext(); + MLIRContext *context = &this->getTypeConverter()->getContext(); LLVM::LLVMType llvmVoidType = LLVM::LLVMType::getVoidTy(context); LLVM::LLVMType llvmPointerType = LLVM::LLVMType::getInt8PtrTy(context); @@ -81,7 +81,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> { LLVM::LLVMType llvmInt32Type = LLVM::LLVMType::getInt32Ty(context); LLVM::LLVMType llvmInt64Type = LLVM::LLVMType::getInt64Ty(context); LLVM::LLVMType llvmIntPtrType = LLVM::LLVMType::getIntNTy( - context, this->typeConverter.getPointerBitwidth(0)); + context, this->getTypeConverter()->getPointerBitwidth(0)); FunctionCallBuilder moduleLoadCallBuilder = { "mgpuModuleLoad", @@ -333,8 +333,8 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite( auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType(); auto elementSize = getSizeInBytes(loc, elementType, rewriter); - auto arguments = - typeConverter.promoteOperands(loc, op->getOperands(), operands, rewriter); + auto arguments = getTypeConverter()->promoteOperands(loc, op->getOperands(), + operands, rewriter); arguments.push_back(elementSize); hostRegisterCallBuilder.create(loc, rewriter, arguments); @@ -486,7 +486,7 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray( OpBuilder &builder) const { auto loc = launchOp.getLoc(); auto numKernelOperands = launchOp.getNumKernelOperands(); - auto arguments = typeConverter.promoteOperands( + auto arguments = getTypeConverter()->promoteOperands( loc, launchOp.getOperands().take_back(numKernelOperands), operands.take_back(numKernelOperands), builder); auto numArguments = arguments.size(); diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h index a3fad7e71c84..69ea393e5df1 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -41,7 +41,7 @@ struct GPUFuncOpLowering : 
ConvertToLLVMPattern { uint64_t numElements = type.getNumElements(); - auto elementType = typeConverter.convertType(type.getElementType()) + auto elementType = typeConverter->convertType(type.getElementType()) .template cast<LLVM::LLVMType>(); auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements); std::string name = std::string( @@ -54,14 +54,14 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern { } // Rewrite the original GPU function to an LLVM function. - auto funcType = typeConverter.convertType(gpuFuncOp.getType()) + auto funcType = typeConverter->convertType(gpuFuncOp.getType()) .template cast<LLVM::LLVMType>() .getPointerElementTy(); // Remap proper input types. TypeConverter::SignatureConversion signatureConversion( gpuFuncOp.front().getNumArguments()); - typeConverter.convertFunctionSignature( + getTypeConverter()->convertFunctionSignature( gpuFuncOp.getType(), /*isVariadic=*/false, signatureConversion); // Create the new function operation. Only copy those attributes that are @@ -110,7 +110,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern { Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()]; auto type = attribution.getType().cast<MemRefType>(); auto descr = MemRefDescriptor::fromStaticShape( - rewriter, loc, typeConverter, type, memory); + rewriter, loc, *getTypeConverter(), type, memory); signatureConversion.remapInput(numProperArguments + en.index(), descr); } @@ -127,7 +127,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern { // Explicitly drop memory space when lowering private memory // attributions since NVVM models it as `alloca`s in the default // memory space and does not support `alloca`s with addrspace(5). 
- auto ptrType = typeConverter.convertType(type.getElementType()) + auto ptrType = typeConverter->convertType(type.getElementType()) .template cast<LLVM::LLVMType>() .getPointerTo(AllocaAddrSpace); Value numElements = rewriter.create<LLVM::ConstantOp>( @@ -136,7 +136,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern { Value allocated = rewriter.create<LLVM::AllocaOp>( gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0); auto descr = MemRefDescriptor::fromStaticShape( - rewriter, loc, typeConverter, type, allocated); + rewriter, loc, *getTypeConverter(), type, allocated); signatureConversion.remapInput( numProperArguments + numWorkgroupAttributions + en.index(), descr); } @@ -145,8 +145,8 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern { // Move the region to the new function, update the entry block signature. rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(), llvmFuncOp.end()); - if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), typeConverter, - &signatureConversion))) + if (failed(rewriter.convertRegionTypes( + &llvmFuncOp.getBody(), *typeConverter, &signatureConversion))) return failure(); rewriter.eraseOp(gpuFuncOp); diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index f32c664c17c4..b907703995d8 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -135,8 +135,8 @@ class RangeOpConversion : public ConvertToLLVMPattern { matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { auto rangeOp = cast<RangeOp>(op); - auto rangeDescriptorTy = - convertRangeType(rangeOp.getType().cast<RangeType>(), typeConverter); + auto rangeDescriptorTy = convertRangeType( + rangeOp.getType().cast<RangeType>(), *getTypeConverter()); edsc::ScopedContext context(rewriter, op->getLoc()); @@ -181,7 +181,7 @@ class ReshapeOpConversion : public 
ConvertToLLVMPattern { edsc::ScopedContext context(rewriter, op->getLoc()); ReshapeOpAdaptor adaptor(operands); BaseViewConversionHelper baseDesc(adaptor.src()); - BaseViewConversionHelper desc(typeConverter.convertType(dstType)); + BaseViewConversionHelper desc(typeConverter->convertType(dstType)); desc.setAllocatedPtr(baseDesc.allocatedPtr()); desc.setAlignedPtr(baseDesc.alignedPtr()); desc.setOffset(baseDesc.offset()); @@ -214,11 +214,11 @@ class SliceOpConversion : public ConvertToLLVMPattern { auto sliceOp = cast<SliceOp>(op); auto memRefType = sliceOp.getBaseViewType(); - auto int64Ty = typeConverter.convertType(rewriter.getIntegerType(64)) + auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64)) .cast<LLVM::LLVMType>(); BaseViewConversionHelper desc( - typeConverter.convertType(sliceOp.getShapedType())); + typeConverter->convertType(sliceOp.getShapedType())); // TODO: extract sizes and emit asserts. SmallVector<Value, 4> strides(memRefType.getRank()); diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 91e97ca1ec50..c589ef69f2c4 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -35,7 +35,7 @@ struct RegionOpConversion : public ConvertToLLVMPattern { curOp.getAttrs()); rewriter.inlineRegionBefore(curOp.region(), newOp.region(), newOp.region().end()); - if (failed(rewriter.convertRegionTypes(&newOp.region(), typeConverter))) + if (failed(rewriter.convertRegionTypes(&newOp.region(), *typeConverter))) return failure(); rewriter.eraseOp(op); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp index 525a5be24485..f83f72d1d10e 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -224,7 +224,7 @@ class GPULaunchLowering : 
public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> { spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()]; auto pointeeType = spirvGlobal.type().cast<spirv::PointerType>().getPointeeType(); - auto dstGlobalType = typeConverter.convertType(pointeeType); + auto dstGlobalType = typeConverter->convertType(pointeeType); if (!dstGlobalType) return failure(); std::string name = diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index f54ffc1c9d6c..17a065463297 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -446,8 +446,7 @@ ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter, PatternBenefit benefit) - : ConversionPattern(rootOpName, benefit, typeConverter, context), - typeConverter(typeConverter) {} + : ConversionPattern(rootOpName, benefit, typeConverter, context) {} //===----------------------------------------------------------------------===// // StructBuilder implementation @@ -1013,27 +1012,32 @@ void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep); } +LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const { + return static_cast<LLVMTypeConverter *>( + ConversionPattern::getTypeConverter()); +} + LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { - return *typeConverter.getDialect(); + return *getTypeConverter()->getDialect(); } LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const { - return typeConverter.getIndexType(); + return getTypeConverter()->getIndexType(); } LLVM::LLVMType ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { return LLVM::LLVMType::getIntNTy( - &typeConverter.getContext(), - typeConverter.getPointerBitwidth(addressSpace)); + &getTypeConverter()->getContext(), + 
getTypeConverter()->getPointerBitwidth(addressSpace)); } LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const { - return LLVM::LLVMType::getVoidTy(&typeConverter.getContext()); + return LLVM::LLVMType::getVoidTy(&getTypeConverter()->getContext()); } LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const { - return LLVM::LLVMType::getInt8PtrTy(&typeConverter.getContext()); + return LLVM::LLVMType::getInt8PtrTy(&getTypeConverter()->getContext()); } Value ConvertToLLVMPattern::createIndexConstant( @@ -1086,7 +1090,7 @@ Value ConvertToLLVMPattern::getDataPtr( // Check if the MemRefType `type` is supported by the lowering. We currently // only support memrefs with identity maps. bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const { - if (!typeConverter.convertType(type.getElementType())) + if (!typeConverter->convertType(type.getElementType())) return false; return type.getAffineMaps().empty() || llvm::all_of(type.getAffineMaps(), @@ -1095,7 +1099,7 @@ bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const { Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { auto elementType = type.getElementType(); - auto structElementType = unwrap(typeConverter.convertType(elementType)); + auto structElementType = unwrap(typeConverter->convertType(elementType)); return structElementType.getPointerTo(type.getMemorySpace()); } @@ -1155,7 +1159,7 @@ Value ConvertToLLVMPattern::getSizeInBytes( // %1 = ptrtoint %elementType* %0 to %indexType // which is a common pattern of getting the size of a type in bytes. 
auto convertedPtrType = - typeConverter.convertType(type).cast<LLVM::LLVMType>().getPointerTo(); + typeConverter->convertType(type).cast<LLVM::LLVMType>().getPointerTo(); auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType); auto gep = rewriter.create<LLVM::GEPOp>( loc, convertedPtrType, @@ -1179,7 +1183,7 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef<Value> sizes, ArrayRef<Value> strides, ConversionPatternRewriter &rewriter) const { - auto structType = typeConverter.convertType(memRefType); + auto structType = typeConverter->convertType(memRefType); auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); // Field 1: Allocated pointer, used for malloc/free. @@ -1347,7 +1351,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> { // LLVMTypeConverter provided to this legalization pattern. auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs"); TypeConverter::SignatureConversion result(funcOp.getNumArguments()); - auto llvmType = typeConverter.convertFunctionSignature( + auto llvmType = getTypeConverter()->convertFunctionSignature( funcOp.getType(), varargsAttr && varargsAttr.getValue(), result); if (!llvmType) return nullptr; @@ -1379,7 +1383,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> { attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); - if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, + if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter, &result))) return nullptr; @@ -1402,14 +1406,14 @@ struct FuncOpConversion : public FuncOpConversionBase { if (!newFuncOp) return failure(); - if (typeConverter.getOptions().emitCWrappers || + if (getTypeConverter()->getOptions().emitCWrappers || funcOp.getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) { if (newFuncOp.isExternal()) - 
wrapExternalFunction(rewriter, funcOp.getLoc(), typeConverter, funcOp, - newFuncOp); + wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(), + funcOp, newFuncOp); else - wrapForExternalCallers(rewriter, funcOp.getLoc(), typeConverter, funcOp, - newFuncOp); + wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(), + funcOp, newFuncOp); } rewriter.eraseOp(funcOp); @@ -1472,7 +1476,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase { rewriter.replaceUsesOfBlockArgument(arg, placeholder); Value desc = MemRefDescriptor::fromStaticShape( - rewriter, loc, typeConverter, memrefTy, arg); + rewriter, loc, *getTypeConverter(), memrefTy, arg); rewriter.replaceOp(placeholder, {desc}); } @@ -1757,7 +1761,7 @@ struct CreateComplexOpLowering // Pack real and imaginary part in a complex number struct. auto loc = op.getLoc(); - auto structType = typeConverter.convertType(complexOp.getType()); + auto structType = typeConverter->convertType(complexOp.getType()); auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); complexStruct.setReal(rewriter, loc, transformed.real()); complexStruct.setImaginary(rewriter, loc, transformed.imaginary()); @@ -1836,7 +1840,7 @@ struct AddCFOpLowering : public ConvertOpToLLVMPattern<AddCFOp> { unpackBinaryComplexOperands<AddCFOp>(op, operands, rewriter); // Initialize complex number struct for result. - auto structType = this->typeConverter.convertType(op.getType()); + auto structType = typeConverter->convertType(op.getType()); auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. @@ -1863,7 +1867,7 @@ struct SubCFOpLowering : public ConvertOpToLLVMPattern<SubCFOp> { unpackBinaryComplexOperands<SubCFOp>(op, operands, rewriter); // Initialize complex number struct for result. 
- auto structType = this->typeConverter.convertType(op.getType()); + auto structType = typeConverter->convertType(op.getType()); auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to substract complex numbers. @@ -1887,7 +1891,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> { ConversionPatternRewriter &rewriter) const override { // If constant refers to a function, convert it to "addressof". if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) { - auto type = typeConverter.convertType(op.getResult().getType()) + auto type = typeConverter->convertType(op.getResult().getType()) .dyn_cast_or_null<LLVM::LLVMType>(); if (!type) return rewriter.notifyMatchFailure(op, "failed to convert result type"); @@ -1905,9 +1909,9 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> { return rewriter.notifyMatchFailure( op, "referring to a symbol outside of the current module"); - return LLVM::detail::oneToOneRewrite(op, - LLVM::ConstantOp::getOperationName(), - operands, typeConverter, rewriter); + return LLVM::detail::oneToOneRewrite( + op, LLVM::ConstantOp::getOperationName(), operands, *getTypeConverter(), + rewriter); } }; @@ -1916,7 +1920,6 @@ struct AllocLikeOpLowering : public ConvertToLLVMPattern { using ConvertToLLVMPattern::createIndexConstant; using ConvertToLLVMPattern::getIndexType; using ConvertToLLVMPattern::getVoidPtrType; - using ConvertToLLVMPattern::typeConverter; explicit AllocLikeOpLowering(StringRef opName, LLVMTypeConverter &converter) : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {} @@ -2288,11 +2291,11 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { if (numResults != 0) { if (!(packedResult = - this->typeConverter.packFunctionResults(resultTypes))) + this->getTypeConverter()->packFunctionResults(resultTypes))) return failure(); } - auto promoted = this->typeConverter.promoteOperands( + auto promoted = 
this->getTypeConverter()->promoteOperands( callOp.getLoc(), /*opOperands=*/callOp->getOperands(), operands, rewriter); auto newOp = rewriter.create<LLVM::CallOp>( @@ -2309,23 +2312,23 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { auto type = - this->typeConverter.convertType(callOp.getResult(i).getType()); + this->typeConverter->convertType(callOp.getResult(i).getType()); results.push_back(rewriter.create<LLVM::ExtractValueOp>( callOp.getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i))); } } - if (this->typeConverter.getOptions().useBarePtrCallConv) { + if (this->getTypeConverter()->getOptions().useBarePtrCallConv) { // For the bare-ptr calling convention, promote memref results to // descriptors. assert(results.size() == resultTypes.size() && "The number of arguments and types doesn't match"); - this->typeConverter.promoteBarePtrsToDescriptors( + this->getTypeConverter()->promoteBarePtrsToDescriptors( rewriter, callOp.getLoc(), resultTypes, results); } else if (failed(copyUnrankedDescriptors(rewriter, callOp.getLoc(), - this->typeConverter, resultTypes, - results, + *this->getTypeConverter(), + resultTypes, results, /*toDynamic=*/false))) { return failure(); } @@ -2410,7 +2413,8 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> { if (!isSupportedMemRefType(type)) return failure(); - LLVM::LLVMType arrayTy = convertGlobalMemrefTypeToLLVM(type, typeConverter); + LLVM::LLVMType arrayTy = + convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); LLVM::Linkage linkage = global.isPublic() ? 
LLVM::Linkage::External : LLVM::Linkage::Private; @@ -2449,14 +2453,15 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering { MemRefType type = getGlobalOp.result().getType().cast<MemRefType>(); unsigned memSpace = type.getMemorySpace(); - LLVM::LLVMType arrayTy = convertGlobalMemrefTypeToLLVM(type, typeConverter); + LLVM::LLVMType arrayTy = + convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); auto addressOf = rewriter.create<LLVM::AddressOfOp>( loc, arrayTy.getPointerTo(memSpace), getGlobalOp.name()); // Get the address of the first element in the array by creating a GEP with // the address of the GV as the base, and (rank + 1) number of 0 indices. LLVM::LLVMType elementType = - unwrap(typeConverter.convertType(type.getElementType())); + unwrap(typeConverter->convertType(type.getElementType())); LLVM::LLVMType elementPtrType = elementType.getPointerTo(memSpace); SmallVector<Value, 4> operands = {addressOf}; @@ -2517,7 +2522,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> { return failure(); return handleMultidimensionalVectors( - op.getOperation(), operands, typeConverter, + op.getOperation(), operands, *getTypeConverter(), [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( mlir::VectorType::get({llvmVectorTy.getVectorNumElements()}, @@ -2546,8 +2551,8 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> { // a sanity check that the underlying structs are the same. Once op // semantics are relaxed we can revisit. 
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) - return success(typeConverter.convertType(srcType) == - typeConverter.convertType(dstType)); + return success(typeConverter->convertType(srcType) == + typeConverter->convertType(dstType)); // At least one of the operands is unranked type assert(srcType.isa<UnrankedMemRefType>() || @@ -2566,7 +2571,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> { auto srcType = memRefCastOp.getOperand().getType(); auto dstType = memRefCastOp.getType(); - auto targetStructType = typeConverter.convertType(memRefCastOp.getType()); + auto targetStructType = typeConverter->convertType(memRefCastOp.getType()); auto loc = memRefCastOp.getLoc(); // For ranked/ranked case, just keep the original descriptor. @@ -2581,7 +2586,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> { auto srcMemRefType = srcType.cast<MemRefType>(); int64_t rank = srcMemRefType.getRank(); // ptr = AllocaOp sizeof(MemRefDescriptor) - auto ptr = typeConverter.promoteOneMemRefDescriptor( + auto ptr = getTypeConverter()->promoteOneMemRefDescriptor( loc, transformed.source(), rewriter); // voidptr = BitCastOp srcType* to void* auto voidPtr = @@ -2589,7 +2594,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> { .getResult(); // rank = ConstantOp srcRank auto rankVal = rewriter.create<LLVM::ConstantOp>( - loc, typeConverter.convertType(rewriter.getIntegerType(64)), + loc, typeConverter->convertType(rewriter.getIntegerType(64)), rewriter.getI64IntegerAttr(rank)); // undef = UndefOp UnrankedMemRefDescriptor memRefDesc = @@ -2693,7 +2698,7 @@ struct MemRefReinterpretCastOpLowering Value *descriptor) const { MemRefType targetMemRefType = castOp.getResult().getType().cast<MemRefType>(); - auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) + auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) .dyn_cast_or_null<LLVM::LLVMType>(); if 
(!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) return failure(); @@ -2704,8 +2709,9 @@ struct MemRefReinterpretCastOpLowering // Set allocated and aligned pointers. Value allocatedPtr, alignedPtr; - extractPointersAndOffset(loc, rewriter, typeConverter, castOp.source(), - adaptor.source(), &allocatedPtr, &alignedPtr); + extractPointersAndOffset(loc, rewriter, *getTypeConverter(), + castOp.source(), adaptor.source(), &allocatedPtr, + &alignedPtr); desc.setAllocatedPtr(rewriter, loc, allocatedPtr); desc.setAlignedPtr(rewriter, loc, alignedPtr); @@ -2779,10 +2785,10 @@ struct MemRefReshapeOpLowering // Create the unranked memref descriptor that holds the ranked one. The // inner descriptor is allocated on stack. auto targetDesc = UnrankedMemRefDescriptor::undef( - rewriter, loc, unwrap(typeConverter.convertType(targetType))); + rewriter, loc, unwrap(typeConverter->convertType(targetType))); targetDesc.setRank(rewriter, loc, resultRank); SmallVector<Value, 4> sizes; - UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter, + UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), targetDesc, sizes); Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>( loc, getVoidPtrType(), sizes.front(), llvm::None); @@ -2790,37 +2796,38 @@ struct MemRefReshapeOpLowering // Extract pointers and offset from the source memref. Value allocatedPtr, alignedPtr, offset; - extractPointersAndOffset(loc, rewriter, typeConverter, reshapeOp.source(), - adaptor.source(), &allocatedPtr, &alignedPtr, - &offset); + extractPointersAndOffset(loc, rewriter, *getTypeConverter(), + reshapeOp.source(), adaptor.source(), + &allocatedPtr, &alignedPtr, &offset); // Set pointers and offset. 
LLVM::LLVMType llvmElementType = - unwrap(typeConverter.convertType(elementType)); + unwrap(typeConverter->convertType(elementType)); LLVM::LLVMType elementPtrPtrType = llvmElementType.getPointerTo(addressSpace).getPointerTo(); UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr, elementPtrPtrType, allocatedPtr); - UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, typeConverter, + UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrPtrType, alignedPtr); - UnrankedMemRefDescriptor::setOffset(rewriter, loc, typeConverter, + UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrPtrType, offset); // Use the offset pointer as base for further addressing. Copy over the new // shape and compute strides. For this, we create a loop from rank-1 to 0. Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr( - rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); + rewriter, loc, *getTypeConverter(), underlyingDescPtr, + elementPtrPtrType); Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr( - rewriter, loc, typeConverter, targetSizesBase, resultRank); + rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank); Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); Value oneIndex = createIndexConstant(rewriter, loc, 1); Value resultRankMinusOne = rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex); Block *initBlock = rewriter.getInsertionBlock(); - LLVM::LLVMType indexType = typeConverter.getIndexType(); + LLVM::LLVMType indexType = getTypeConverter()->getIndexType(); Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint()); Block *condBlock = rewriter.createBlock(initBlock->getParent(), {}, @@ -2854,11 +2861,11 @@ struct MemRefReshapeOpLowering Value sizeLoadGep = rewriter.create<LLVM::GEPOp>( loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg}); Value size = 
rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep); - UnrankedMemRefDescriptor::setSize(rewriter, loc, typeConverter, + UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(), targetSizesBase, indexArg, size); // Write stride value and compute next one. - UnrankedMemRefDescriptor::setStride(rewriter, loc, typeConverter, + UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(), targetStridesBase, indexArg, strideArg); Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size); @@ -2892,7 +2899,7 @@ struct DialectCastOpLowering ConversionPatternRewriter &rewriter) const override { LLVM::DialectCastOp::Adaptor transformed(operands); if (transformed.in().getType() != - typeConverter.convertType(castOp.getType())) { + typeConverter->convertType(castOp.getType())) { return failure(); } rewriter.replaceOp(castOp, transformed.in()); @@ -2942,15 +2949,16 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> { Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc); Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>( loc, - typeConverter.convertType(scalarMemRefType) + typeConverter->convertType(scalarMemRefType) .cast<LLVM::LLVMType>() .getPointerTo(addressSpace), underlyingRankedDesc); // Get pointer to offset field of memref<element_type> descriptor. - Type indexPtrTy = typeConverter.getIndexType().getPointerTo(addressSpace); + Type indexPtrTy = + getTypeConverter()->getIndexType().getPointerTo(addressSpace); Value two = rewriter.create<LLVM::ConstantOp>( - loc, typeConverter.convertType(rewriter.getI32Type()), + loc, typeConverter->convertType(rewriter.getI32Type()), rewriter.getI32IntegerAttr(2)); Value offsetPtr = rewriter.create<LLVM::GEPOp>( loc, indexPtrTy, scalarMemRefDescPtr, @@ -3082,7 +3090,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> { transformed.indices(), rewriter); // Replace with llvm.prefetch. 
- auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); + auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32)); auto isWrite = rewriter.create<LLVM::ConstantOp>( loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite())); auto localityHint = rewriter.create<LLVM::ConstantOp>( @@ -3110,7 +3118,7 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> { IndexCastOpAdaptor transformed(operands); auto targetType = - this->typeConverter.convertType(indexCastOp.getResult().getType()) + typeConverter->convertType(indexCastOp.getResult().getType()) .cast<LLVM::LLVMType>(); auto sourceType = transformed.in().getType().cast<LLVM::LLVMType>(); unsigned targetBits = targetType.getIntegerBitWidth(); @@ -3144,7 +3152,7 @@ struct CmpIOpLowering : public ConvertOpToLLVMPattern<CmpIOp> { CmpIOpAdaptor transformed(operands); rewriter.replaceOpWithNewOp<LLVM::ICmpOp>( - cmpiOp, typeConverter.convertType(cmpiOp.getResult().getType()), + cmpiOp, typeConverter->convertType(cmpiOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast<int64_t>( convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))), transformed.lhs(), transformed.rhs()); @@ -3162,7 +3170,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> { CmpFOpAdaptor transformed(operands); rewriter.replaceOpWithNewOp<LLVM::FCmpOp>( - cmpfOp, typeConverter.convertType(cmpfOp.getResult().getType()), + cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast<int64_t>( convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))), transformed.lhs(), transformed.rhs()); @@ -3248,7 +3256,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> { unsigned numArguments = op.getNumOperands(); SmallVector<Value, 4> updatedOperands; - if (typeConverter.getOptions().useBarePtrCallConv) { + if (getTypeConverter()->getOptions().useBarePtrCallConv) { // For the bare-ptr calling 
convention, extract the aligned pointer to // be returned from the memref descriptor. for (auto it : llvm::zip(op->getOperands(), operands)) { @@ -3266,7 +3274,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> { } } else { updatedOperands = llvm::to_vector<4>(operands); - copyUnrankedDescriptors(rewriter, loc, typeConverter, + copyUnrankedDescriptors(rewriter, loc, *getTypeConverter(), op.getOperands().getTypes(), updatedOperands, /*toDynamic=*/true); } @@ -3285,7 +3293,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> { // Otherwise, we need to pack the arguments into an LLVM struct type before // returning. - auto packedType = typeConverter.packFunctionResults( + auto packedType = getTypeConverter()->packFunctionResults( llvm::to_vector<4>(op.getOperandTypes())); Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType); @@ -3323,11 +3331,11 @@ struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> { return failure(); // First insert it into an undef vector so we can shuffle it. - auto vectorType = typeConverter.convertType(splatOp.getType()); + auto vectorType = typeConverter->convertType(splatOp.getType()); Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType); auto zero = rewriter.create<LLVM::ConstantOp>( splatOp.getLoc(), - typeConverter.convertType(rewriter.getIntegerType(32)), + typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); auto v = rewriter.create<LLVM::InsertElementOp>( @@ -3360,7 +3368,8 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> { // First insert it into an undef vector so we can shuffle it. 
auto loc = splatOp.getLoc(); - auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, typeConverter); + auto vectorTypeInfo = + extractNDVectorTypeInfo(resultType, *getTypeConverter()); auto llvmArrayTy = vectorTypeInfo.llvmArrayTy; auto llvmVectorTy = vectorTypeInfo.llvmVectorTy; if (!llvmArrayTy || !llvmVectorTy) @@ -3373,7 +3382,7 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> { // places within the returned descriptor. Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy); auto zero = rewriter.create<LLVM::ConstantOp>( - loc, typeConverter.convertType(rewriter.getIntegerType(32)), + loc, typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc, adaptor.input(), zero); @@ -3418,7 +3427,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> { auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>(); auto sourceElementTy = - typeConverter.convertType(sourceMemRefType.getElementType()) + typeConverter->convertType(sourceMemRefType.getElementType()) .dyn_cast_or_null<LLVM::LLVMType>(); auto viewMemRefType = subViewOp.getType(); @@ -3429,9 +3438,9 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> { extractFromI64ArrayAttr(subViewOp.static_strides())) .cast<MemRefType>(); auto targetElementTy = - typeConverter.convertType(viewMemRefType.getElementType()) + typeConverter->convertType(viewMemRefType.getElementType()) .dyn_cast<LLVM::LLVMType>(); - auto targetDescTy = typeConverter.convertType(viewMemRefType) + auto targetDescTy = typeConverter->convertType(viewMemRefType) .dyn_cast_or_null<LLVM::LLVMType>(); if (!sourceElementTy || !targetDescTy) return failure(); @@ -3477,7 +3486,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> { strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); // Offset. 
- auto llvmIndexType = typeConverter.convertType(rewriter.getIndexType()); + auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType()); if (!ShapedType::isDynamicStrideOrOffset(offset)) { targetMemRef.setConstantOffset(rewriter, loc, offset); } else { @@ -3553,7 +3562,7 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<TransposeOp> { return rewriter.replaceOp(transposeOp, {viewMemRef}), success(); auto targetMemRef = MemRefDescriptor::undef( - rewriter, loc, typeConverter.convertType(transposeOp.getShapedType())); + rewriter, loc, typeConverter->convertType(transposeOp.getShapedType())); // Copy the base and aligned pointers from the old descriptor to the new // one. @@ -3629,10 +3638,10 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> { auto viewMemRefType = viewOp.getType(); auto targetElementTy = - typeConverter.convertType(viewMemRefType.getElementType()) + typeConverter->convertType(viewMemRefType.getElementType()) .dyn_cast<LLVM::LLVMType>(); auto targetDescTy = - typeConverter.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>(); + typeConverter->convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>(); if (!targetDescTy) return viewOp.emitWarning("Target descriptor type not converted to LLVM"), failure(); @@ -3825,7 +3834,7 @@ struct GenericAtomicRMWOpLowering auto loc = atomicOp.getLoc(); GenericAtomicRMWOp::Adaptor adaptor(operands); LLVM::LLVMType valueType = - typeConverter.convertType(atomicOp.getResult().getType()) + typeConverter->convertType(atomicOp.getResult().getType()) .cast<LLVM::LLVMType>(); // Split the block into initial, loop, and ending parts. 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index b3fa315b75a3..85d3e2bddd66 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -309,7 +309,7 @@ class VectorMatmulOpConversion : public ConvertToLLVMPattern { auto matmulOp = cast<vector::MatmulOp>(op); auto adaptor = vector::MatmulOpAdaptor(operands); rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>( - op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(), + op, typeConverter->convertType(matmulOp.res().getType()), adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(), matmulOp.rhs_columns()); return success(); @@ -331,7 +331,7 @@ class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern { auto transOp = cast<vector::FlatTransposeOp>(op); auto adaptor = vector::FlatTransposeOpAdaptor(operands); rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>( - transOp, typeConverter.convertType(transOp.res().getType()), + transOp, typeConverter->convertType(transOp.res().getType()), adaptor.matrix(), transOp.rows(), transOp.columns()); return success(); } @@ -354,10 +354,10 @@ class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern { // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(typeConverter, load, align))) + if (failed(getMemRefAlignment(*getTypeConverter(), load, align))) return failure(); - auto vtype = typeConverter.convertType(load.getResultVectorType()); + auto vtype = typeConverter->convertType(load.getResultVectorType()); Value ptr; if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(), vtype, ptr))) @@ -387,10 +387,10 @@ class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern { // Resolve alignment. 
unsigned align; - if (failed(getMemRefAlignment(typeConverter, store, align))) + if (failed(getMemRefAlignment(*getTypeConverter(), store, align))) return failure(); - auto vtype = typeConverter.convertType(store.getValueVectorType()); + auto vtype = typeConverter->convertType(store.getValueVectorType()); Value ptr; if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(), vtype, ptr))) @@ -420,7 +420,7 @@ class VectorGatherOpConversion : public ConvertToLLVMPattern { // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(typeConverter, gather, align))) + if (failed(getMemRefAlignment(*getTypeConverter(), gather, align))) return failure(); // Get index ptrs. @@ -433,7 +433,7 @@ class VectorGatherOpConversion : public ConvertToLLVMPattern { // Replace with the gather intrinsic. rewriter.replaceOpWithNewOp<LLVM::masked_gather>( - gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), + gather, typeConverter->convertType(vType), ptrs, adaptor.mask(), adaptor.pass_thru(), rewriter.getI32IntegerAttr(align)); return success(); } @@ -456,7 +456,7 @@ class VectorScatterOpConversion : public ConvertToLLVMPattern { // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(typeConverter, scatter, align))) + if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align))) return failure(); // Get index ptrs. 
@@ -497,7 +497,7 @@ class VectorExpandLoadOpConversion : public ConvertToLLVMPattern { auto vType = expand.getResultVectorType(); rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( - op, typeConverter.convertType(vType), ptr, adaptor.mask(), + op, typeConverter->convertType(vType), ptr, adaptor.mask(), adaptor.pass_thru()); return success(); } @@ -545,7 +545,7 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern { auto reductionOp = cast<vector::ReductionOp>(op); auto kind = reductionOp.kind(); Type eltType = reductionOp.dest().getType(); - Type llvmType = typeConverter.convertType(eltType); + Type llvmType = typeConverter->convertType(eltType); if (eltType.isIntOrIndex()) { // Integer reductions: add/mul/min/max/and/or/xor. if (kind == "add") @@ -580,39 +580,40 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern { else return failure(); return success(); - - } else if (eltType.isa<FloatType>()) { - // Floating-point reductions: add/mul/min/max - if (kind == "add") { - // Optional accumulator (or zero). - Value acc = operands.size() > 1 ? operands[1] - : rewriter.create<LLVM::ConstantOp>( - op->getLoc(), llvmType, - rewriter.getZeroAttr(eltType)); - rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>( - op, llvmType, acc, operands[0], - rewriter.getBoolAttr(reassociateFPReductions)); - } else if (kind == "mul") { - // Optional accumulator (or one). - Value acc = operands.size() > 1 - ? 
operands[1] - : rewriter.create<LLVM::ConstantOp>( - op->getLoc(), llvmType, - rewriter.getFloatAttr(eltType, 1.0)); - rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>( - op, llvmType, acc, operands[0], - rewriter.getBoolAttr(reassociateFPReductions)); - } else if (kind == "min") - rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>( - op, llvmType, operands[0]); - else if (kind == "max") - rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>( - op, llvmType, operands[0]); - else - return failure(); - return success(); } - return failure(); + + if (!eltType.isa<FloatType>()) + return failure(); + + // Floating-point reductions: add/mul/min/max + if (kind == "add") { + // Optional accumulator (or zero). + Value acc = operands.size() > 1 ? operands[1] + : rewriter.create<LLVM::ConstantOp>( + op->getLoc(), llvmType, + rewriter.getZeroAttr(eltType)); + rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>( + op, llvmType, acc, operands[0], + rewriter.getBoolAttr(reassociateFPReductions)); + } else if (kind == "mul") { + // Optional accumulator (or one). + Value acc = operands.size() > 1 + ? 
operands[1] + : rewriter.create<LLVM::ConstantOp>( + op->getLoc(), llvmType, + rewriter.getFloatAttr(eltType, 1.0)); + rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>( + op, llvmType, acc, operands[0], + rewriter.getBoolAttr(reassociateFPReductions)); + } else if (kind == "min") + rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(op, llvmType, + operands[0]); + else if (kind == "max") + rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(op, llvmType, + operands[0]); + else + return failure(); + return success(); } private: @@ -663,7 +664,7 @@ class VectorShuffleOpConversion : public ConvertToLLVMPattern { auto v1Type = shuffleOp.getV1VectorType(); auto v2Type = shuffleOp.getV2VectorType(); auto vectorType = shuffleOp.getVectorType(); - Type llvmType = typeConverter.convertType(vectorType); + Type llvmType = typeConverter->convertType(vectorType); auto maskArrayAttr = shuffleOp.mask(); // Bail if result type cannot be lowered. @@ -695,9 +696,9 @@ class VectorShuffleOpConversion : public ConvertToLLVMPattern { extPos -= v1Dim; value = adaptor.v2(); } - Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType, - rank, extPos); - insert = insertOne(rewriter, typeConverter, loc, insert, extract, + Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, + llvmType, rank, extPos); + insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, llvmType, rank, insPos++); } rewriter.replaceOp(op, insert); @@ -718,7 +719,7 @@ class VectorExtractElementOpConversion : public ConvertToLLVMPattern { auto adaptor = vector::ExtractElementOpAdaptor(operands); auto extractEltOp = cast<vector::ExtractElementOp>(op); auto vectorType = extractEltOp.getVectorType(); - auto llvmType = typeConverter.convertType(vectorType.getElementType()); + auto llvmType = typeConverter->convertType(vectorType.getElementType()); // Bail if result type cannot be lowered. 
if (!llvmType) @@ -745,7 +746,7 @@ class VectorExtractOpConversion : public ConvertToLLVMPattern { auto extractOp = cast<vector::ExtractOp>(op); auto vectorType = extractOp.getVectorType(); auto resultType = extractOp.getResult().getType(); - auto llvmResultType = typeConverter.convertType(resultType); + auto llvmResultType = typeConverter->convertType(resultType); auto positionArrayAttr = extractOp.position(); // Bail if result type cannot be lowered. @@ -769,7 +770,7 @@ class VectorExtractOpConversion : public ConvertToLLVMPattern { auto nMinusOnePositionAttrs = ArrayAttr::get(positionAttrs.drop_back(), context); extracted = rewriter.create<LLVM::ExtractValueOp>( - loc, typeConverter.convertType(oneDVectorType), extracted, + loc, typeConverter->convertType(oneDVectorType), extracted, nMinusOnePositionAttrs); } @@ -833,7 +834,7 @@ class VectorInsertElementOpConversion : public ConvertToLLVMPattern { auto adaptor = vector::InsertElementOpAdaptor(operands); auto insertEltOp = cast<vector::InsertElementOp>(op); auto vectorType = insertEltOp.getDestVectorType(); - auto llvmType = typeConverter.convertType(vectorType); + auto llvmType = typeConverter->convertType(vectorType); // Bail if result type cannot be lowered. if (!llvmType) @@ -860,7 +861,7 @@ class VectorInsertOpConversion : public ConvertToLLVMPattern { auto insertOp = cast<vector::InsertOp>(op); auto sourceType = insertOp.getSourceType(); auto destVectorType = insertOp.getDestVectorType(); - auto llvmResultType = typeConverter.convertType(destVectorType); + auto llvmResultType = typeConverter->convertType(destVectorType); auto positionArrayAttr = insertOp.position(); // Bail if result type cannot be lowered. 
@@ -887,7 +888,7 @@ class VectorInsertOpConversion : public ConvertToLLVMPattern { auto nMinusOnePositionAttrs = ArrayAttr::get(positionAttrs.drop_back(), context); extracted = rewriter.create<LLVM::ExtractValueOp>( - loc, typeConverter.convertType(oneDVectorType), extracted, + loc, typeConverter->convertType(oneDVectorType), extracted, nMinusOnePositionAttrs); } @@ -895,7 +896,7 @@ class VectorInsertOpConversion : public ConvertToLLVMPattern { auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position); Value inserted = rewriter.create<LLVM::InsertElementOp>( - loc, typeConverter.convertType(oneDVectorType), extracted, + loc, typeConverter->convertType(oneDVectorType), extracted, adaptor.source(), constant); // Potential insertion of resulting 1-D vector into array. @@ -1000,7 +1001,7 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern Value extracted = rewriter.create<ExtractOp>(loc, op.dest(), getI64SubArray(op.offsets(), /*dropFront=*/0, - /*dropFront=*/rankRest)); + /*dropBack=*/rankRest)); // A different pattern will kick in for InsertStridedSlice with matching // ranks. 
auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( @@ -1010,7 +1011,7 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern rewriter.replaceOpWithNewOp<InsertOp>( op, stridedSliceInnerOp.getResult(), op.dest(), getI64SubArray(op.offsets(), /*dropFront=*/0, - /*dropFront=*/rankRest)); + /*dropBack=*/rankRest)); return success(); } }; @@ -1144,7 +1145,7 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern { return failure(); MemRefDescriptor sourceMemRef(operands[0]); - auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) + auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) .dyn_cast_or_null<LLVM::LLVMType>(); if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) return failure(); @@ -1234,7 +1235,7 @@ class VectorTransferConversion : public ConvertToLLVMPattern { if (!strides) return failure(); - auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; + auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); }; Location loc = op->getLoc(); MemRefType memRefType = xferOp.getMemRefType(); @@ -1279,8 +1280,8 @@ class VectorTransferConversion : public ConvertToLLVMPattern { loc, vecTy.getPointerTo(), dataPtr); if (!xferOp.isMaskedDim(0)) - return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc, - xferOp, operands, vectorDataPtr); + return replaceTransferOpWithLoadOrStore( + rewriter, *getTypeConverter(), loc, xferOp, operands, vectorDataPtr); // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. @@ -1297,8 +1298,8 @@ class VectorTransferConversion : public ConvertToLLVMPattern { vecWidth, dim, &off); // 5. Rewrite as a masked read / write. 
- return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp, - operands, vectorDataPtr, mask); + return replaceTransferOpWithMasked(rewriter, *getTypeConverter(), loc, + xferOp, operands, vectorDataPtr, mask); } private: @@ -1331,7 +1332,7 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern { auto adaptor = vector::PrintOpAdaptor(operands); Type printType = printOp.getPrintType(); - if (typeConverter.convertType(printType) == nullptr) + if (typeConverter->convertType(printType) == nullptr) return failure(); // Make sure element type has runtime support. @@ -1421,10 +1422,10 @@ class VectorPrintOpConversion : public ConvertToLLVMPattern { for (int64_t d = 0; d < dim; ++d) { auto reducedType = rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; - auto llvmType = typeConverter.convertType( + auto llvmType = typeConverter->convertType( rank > 1 ? reducedType : vectorType.getElementType()); - Value nestedVal = - extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d); + Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, + llvmType, rank, d); emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, conversion); if (d != dim - 1) diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp index 26b8bec1f3fc..61f094746a0a 100644 --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -79,7 +79,7 @@ class VectorTransferConversion : public ConvertToLLVMPattern { if (!xferOp.isMaskedDim(0)) return failure(); - auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; + auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); }; LLVM::LLVMType vecTy = toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>(); unsigned vecWidth = vecTy.getVectorNumElements(); @@ -142,9 +142,9 @@ class VectorTransferConversion : public ConvertToLLVMPattern { Value int32Zero = 
rewriter.create<LLVM::ConstantOp>( loc, toLLVMTy(i32Ty), rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0)); - return replaceTransferOpWithMubuf(rewriter, operands, typeConverter, loc, - xferOp, vecTy, dwordConfig, int32Zero, - int32Zero, int1False, int1False); + return replaceTransferOpWithMubuf( + rewriter, operands, *getTypeConverter(), loc, xferOp, vecTy, + dwordConfig, int32Zero, int32Zero, int1False, int1False); } }; } // end anonymous namespace _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits