https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116470
>From fe38d4bc65947e7d33854f40927bfdde7aa5186b Mon Sep 17 00:00:00 2001 From: Matthias Springer <msprin...@nvidia.com> Date: Tue, 12 Nov 2024 05:14:43 +0100 Subject: [PATCH] replace with multiple MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Apply suggestions from code review Co-authored-by: Markus Böck <markus.boec...@gmail.com> address comments [WIP] 1:N conversion pattern update test cases --- .../mlir/Conversion/LLVMCommon/Pattern.h | 35 ++- .../mlir/Transforms/DialectConversion.h | 63 +++++ .../Transforms/DecomposeCallGraphTypes.cpp | 56 +--- .../Func/Transforms/FuncConversions.cpp | 5 +- .../Transforms/StructuralTypeConversions.cpp | 106 +++----- .../Transforms/SparseTensorCodegen.cpp | 114 ++++---- .../Transforms/Utils/SparseTensorDescriptor.h | 16 +- .../Transforms/Utils/DialectConversion.cpp | 251 ++++++++++++------ .../decompose-call-graph-types.mlir | 38 +-- 9 files changed, 381 insertions(+), 303 deletions(-) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index f3bf5b66398e09..86ea87b55af1cd 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -143,6 +143,8 @@ template <typename SourceOp> class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { public: using OpAdaptor = typename SourceOp::Adaptor; + using OneToNOpAdaptor = + typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>; explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1) @@ -153,8 +155,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { /// Wrappers around the RewritePattern methods that pass the derived op type. 
void rewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final { - rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)), - rewriter); + auto sourceOp = cast<SourceOp>(op); + rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); + } + void rewrite(Operation *op, ArrayRef<ValueRange> operands, + ConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast<SourceOp>(op); + rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter); } LogicalResult match(Operation *op) const final { return match(cast<SourceOp>(op)); @@ -162,8 +169,15 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final { - return matchAndRewrite(cast<SourceOp>(op), - OpAdaptor(operands, cast<SourceOp>(op)), rewriter); + auto sourceOp = cast<SourceOp>(op); + return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); + } + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, + ConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast<SourceOp>(op); + return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), + rewriter); } /// Rewrite and Match methods that operate on the SourceOp type. 
These must be @@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override rewrite or matchAndRewrite"); } + virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + SmallVector<Value> oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { rewrite(op, adaptor, rewriter); return success(); } + virtual LogicalResult + matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + SmallVector<Value> oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } private: using ConvertToLLVMPattern::match; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index de47765006f81e..e4eeb39b9c0741 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -537,6 +537,10 @@ class ConversionPattern : public RewritePattern { ConversionPatternRewriter &rewriter) const { llvm_unreachable("unimplemented rewrite"); } + virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands, + ConversionPatternRewriter &rewriter) const { + rewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } /// Hook for derived classes to implement combined matching and rewriting. 
virtual LogicalResult @@ -547,6 +551,11 @@ class ConversionPattern : public RewritePattern { rewrite(op, operands, rewriter); return success(); } + virtual LogicalResult + matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, + ConversionPatternRewriter &rewriter) const { + return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } /// Attempt to match and rewrite the IR root at the specified operation. LogicalResult matchAndRewrite(Operation *op, @@ -574,6 +583,15 @@ class ConversionPattern : public RewritePattern { : RewritePattern(std::forward<Args>(args)...), typeConverter(&typeConverter) {} + /// Given an array of value ranges, which are the inputs to a 1:N adaptor, + /// try to extract the single value of each range to construct a the inputs + /// for a 1:1 adaptor. + /// + /// This function produces a fatal error if at least one range has 0 or + /// more than 1 value: "pattern 'name' does not support 1:N conversion" + SmallVector<Value> + getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const; + protected: /// An optional type converter for use by this pattern. 
const TypeConverter *typeConverter = nullptr; @@ -589,6 +607,8 @@ template <typename SourceOp> class OpConversionPattern : public ConversionPattern { public: using OpAdaptor = typename SourceOp::Adaptor; + using OneToNOpAdaptor = + typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>; OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} @@ -607,12 +627,24 @@ class OpConversionPattern : public ConversionPattern { auto sourceOp = cast<SourceOp>(op); rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); } + void rewrite(Operation *op, ArrayRef<ValueRange> operands, + ConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast<SourceOp>(op); + rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter); + } LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final { auto sourceOp = cast<SourceOp>(op); return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); } + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, + ConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast<SourceOp>(op); + return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), + rewriter); + } /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. 
@@ -623,6 +655,12 @@ class OpConversionPattern : public ConversionPattern { ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override matchAndRewrite or a rewrite method"); } + virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + SmallVector<Value> oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -631,6 +669,13 @@ class OpConversionPattern : public ConversionPattern { rewrite(op, adaptor, rewriter); return success(); } + virtual LogicalResult + matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + SmallVector<Value> oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } private: using ConversionPattern::matchAndRewrite; @@ -656,11 +701,20 @@ class OpInterfaceConversionPattern : public ConversionPattern { ConversionPatternRewriter &rewriter) const final { rewrite(cast<SourceOp>(op), operands, rewriter); } + void rewrite(Operation *op, ArrayRef<ValueRange> operands, + ConversionPatternRewriter &rewriter) const final { + rewrite(cast<SourceOp>(op), operands, rewriter); + } LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final { return matchAndRewrite(cast<SourceOp>(op), operands, rewriter); } + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, + ConversionPatternRewriter &rewriter) const final { + return matchAndRewrite(cast<SourceOp>(op), operands, rewriter); + } /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. 
@@ -668,6 +722,10 @@ class OpInterfaceConversionPattern : public ConversionPattern { ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override matchAndRewrite or a rewrite method"); } + virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands, + ConversionPatternRewriter &rewriter) const { + rewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const { @@ -676,6 +734,11 @@ class OpInterfaceConversionPattern : public ConversionPattern { rewrite(op, operands, rewriter); return success(); } + virtual LogicalResult + matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands, + ConversionPatternRewriter &rewriter) const { + return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } private: using ConversionPattern::matchAndRewrite; diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp index a08764326a80b6..03be00328bda33 100644 --- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp +++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp @@ -13,40 +13,6 @@ using namespace mlir; using namespace mlir::func; -//===----------------------------------------------------------------------===// -// Helper functions -//===----------------------------------------------------------------------===// - -/// If the given value can be decomposed with the type converter, decompose it. -/// Otherwise, return the given value. -// TODO: Value decomposition should happen automatically through a 1:N adaptor. -// This function will disappear when the 1:1 and 1:N drivers are merged. -static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, - Value value, - const TypeConverter *converter) { - // Try to convert the given value's type. If that fails, just return the - // given value. 
- SmallVector<Type> convertedTypes; - if (failed(converter->convertType(value.getType(), convertedTypes))) - return {value}; - if (convertedTypes.empty()) - return {}; - - // If the given value's type is already legal, just return the given value. - TypeRange convertedTypeRange(convertedTypes); - if (convertedTypeRange == TypeRange(value.getType())) - return {value}; - - // Try to materialize a target conversion. If the materialization did not - // produce values of the requested type, the materialization failed. Just - // return the given value in that case. - SmallVector<Value> result = converter->materializeTargetConversion( - builder, loc, convertedTypeRange, value); - if (result.empty()) - return {value}; - return result; -} - //===----------------------------------------------------------------------===// // DecomposeCallGraphTypesForFuncArgs //===----------------------------------------------------------------------===// @@ -102,16 +68,11 @@ struct DecomposeCallGraphTypesForReturnOp using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReturnOp op, OpAdaptor adaptor, + matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { SmallVector<Value, 2> newOperands; - for (Value operand : adaptor.getOperands()) { - // TODO: We can directly take the values from the adaptor once this is a - // 1:N conversion pattern. 
- llvm::append_range(newOperands, - decomposeValue(rewriter, operand.getLoc(), operand, - getTypeConverter())); - } + for (ValueRange operand : adaptor.getOperands()) + llvm::append_range(newOperands, operand); rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands); return success(); } @@ -128,18 +89,13 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CallOp op, OpAdaptor adaptor, + matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { // Create the operands list of the new `CallOp`. SmallVector<Value, 2> newOperands; - for (Value operand : adaptor.getOperands()) { - // TODO: We can directly take the values from the adaptor once this is a - // 1:N conversion pattern. - llvm::append_range(newOperands, - decomposeValue(rewriter, operand.getLoc(), operand, - getTypeConverter())); - } + for (ValueRange operand : adaptor.getOperands()) + llvm::append_range(newOperands, operand); // Create the new result types for the new `CallOp` and track the number of // replacement types for each original op result. diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp index eb444d665ff260..d81f822f7d4b51 100644 --- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp @@ -21,7 +21,7 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> { /// Hook for derived classes to implement combined matching and rewriting. LogicalResult - matchAndRewrite(CallOp callOp, OpAdaptor adaptor, + matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Convert the original function results. 
SmallVector<Type, 1> convertedResults; @@ -37,7 +37,8 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> { // Substitute with the new result types from the corresponding FuncType // conversion. rewriter.replaceOpWithNewOp<CallOp>( - callOp, callOp.getCallee(), convertedResults, adaptor.getOperands()); + callOp, callOp.getCallee(), convertedResults, + getOneToOneAdaptorOperands(adaptor.getOperands())); return success(); } }; diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index 93a78056db1944..c0589044c26ecb 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -16,20 +16,18 @@ using namespace mlir::scf; namespace { -// Unpacks the single unrealized_conversion_cast using the list of inputs -// e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d) -static void unpackUnrealizedConversionCast(Value v, - SmallVectorImpl<Value> &unpacked) { - if (auto cast = - dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) { - if (cast.getInputs().size() != 1) { - // 1 : N type conversion. - unpacked.append(cast.getInputs().begin(), cast.getInputs().end()); - return; - } - } - // 1 : 1 type conversion. - unpacked.push_back(v); +/// Flatten the given value ranges into a single vector of values. +static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) { + SmallVector<Value> result; + for (const auto &vals : values) + llvm::append_range(result, vals); + return result; +} + +/// Assert that the given value range contains a single value and return it. 
+static Value getSingleValue(ValueRange values) { + assert(values.size() == 1 && "expected single value"); + return values.front(); } // CRTP @@ -40,19 +38,21 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> { public: using OpConversionPattern<SourceOp>::typeConverter; using OpConversionPattern<SourceOp>::OpConversionPattern; - using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor; + using OneToNOpAdaptor = + typename OpConversionPattern<SourceOp>::OneToNOpAdaptor; // // Derived classes should provide the following method which performs the // actual conversion. It should return std::nullopt upon conversion failure // and return the converted operation upon success. // - // std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor, - // ConversionPatternRewriter &rewriter, - // TypeRange dstTypes) const; + // std::optional<SourceOp> convertSourceOp( + // SourceOp op, OneToNOpAdaptor adaptor, + // ConversionPatternRewriter &rewriter, + // TypeRange dstTypes) const; LogicalResult - matchAndRewrite(SourceOp op, OpAdaptor adaptor, + matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector<Type> dstTypes; SmallVector<unsigned> offsets; @@ -73,28 +73,15 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> { return rewriter.notifyMatchFailure(op, "could not convert operation"); // Packs the return value. - SmallVector<Value> packedRets; + SmallVector<ValueRange> packedRets; for (unsigned i = 1, e = offsets.size(); i < e; i++) { unsigned start = offsets[i - 1], end = offsets[i]; unsigned len = end - start; ValueRange mappedValue = newOp->getResults().slice(start, len); - if (len != 1) { - // 1 : N type conversion. 
- Type origType = op.getResultTypes()[i - 1]; - Value mat = typeConverter->materializeSourceConversion( - rewriter, op.getLoc(), origType, mappedValue); - if (!mat) { - return rewriter.notifyMatchFailure( - op, "Failed to materialize 1:N type conversion"); - } - packedRets.push_back(mat); - } else { - // 1 : 1 type conversion. - packedRets.push_back(mappedValue.front()); - } + packedRets.push_back(mappedValue); } - rewriter.replaceOp(op, packedRets); + rewriter.replaceOpWithMultiple(op, packedRets); return success(); } }; @@ -105,7 +92,7 @@ class ConvertForOpTypes using Structural1ToNConversionPattern::Structural1ToNConversionPattern; // The callback required by CRTP. - std::optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor, + std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter, TypeRange dstTypes) const { // Create a empty new op and inline the regions from the old op. @@ -129,16 +116,13 @@ class ConvertForOpTypes if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter))) return std::nullopt; - // Unpacked the iteration arguments. - SmallVector<Value> flatArgs; - for (Value arg : adaptor.getInitArgs()) - unpackUnrealizedConversionCast(arg, flatArgs); - // We can not do clone as the number of result types after conversion // might be different. - ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(), - adaptor.getUpperBound(), - adaptor.getStep(), flatArgs); + ForOp newOp = rewriter.create<ForOp>( + op.getLoc(), getSingleValue(adaptor.getLowerBound()), + getSingleValue(adaptor.getUpperBound()), + getSingleValue(adaptor.getStep()), + flattenValues(adaptor.getInitArgs())); // Reserve whatever attributes in the original op. 
newOp->setAttrs(op->getAttrs()); @@ -160,12 +144,12 @@ class ConvertIfOpTypes public: using Structural1ToNConversionPattern::Structural1ToNConversionPattern; - std::optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor, + std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter, TypeRange dstTypes) const { - IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes, - adaptor.getCondition(), true); + IfOp newOp = rewriter.create<IfOp>( + op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true); newOp->setAttrs(op->getAttrs()); // We do not need the empty blocks created by rewriter. @@ -189,15 +173,11 @@ class ConvertWhileOpTypes public: using Structural1ToNConversionPattern::Structural1ToNConversionPattern; - std::optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor, + std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter, TypeRange dstTypes) const { - // Unpacked the iteration arguments. 
- SmallVector<Value> flatArgs; - for (Value arg : adaptor.getOperands()) - unpackUnrealizedConversionCast(arg, flatArgs); - - auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs); + auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, + flattenValues(adaptor.getOperands())); for (auto i : {0u, 1u}) { if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter))) @@ -218,13 +198,10 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, + matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector<Value> unpackedYield; - for (Value operand : adaptor.getOperands()) - unpackUnrealizedConversionCast(operand, unpackedYield); - - rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield); + rewriter.replaceOpWithNewOp<scf::YieldOp>( + op, flattenValues(adaptor.getOperands())); return success(); } }; @@ -235,13 +212,10 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> { public: using OpConversionPattern<ConditionOp>::OpConversionPattern; LogicalResult - matchAndRewrite(ConditionOp op, OpAdaptor adaptor, + matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector<Value> unpackedYield; - for (Value operand : adaptor.getOperands()) - unpackUnrealizedConversionCast(operand, unpackedYield); - - rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); }); + rewriter.modifyOpInPlace( + op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); }); return success(); } }; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 25fca49cb0154a..9184224e7aef4b 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp 
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -39,25 +39,18 @@ using namespace mlir::sparse_tensor; // Helper methods. //===----------------------------------------------------------------------===// -/// Flattens a list of operands that may contain sparse tensors. -static void flattenOperands(ValueRange operands, - SmallVectorImpl<Value> &flattened) { - // In case of - // sparse_tensor, c, sparse_tensor - // ==> - // memref ..., c, memref ... - for (auto operand : operands) { - if (getSparseTensorEncoding(operand.getType())) { - auto tuple = getTuple(operand); - // An unrealized_conversion_cast will be inserted by type converter to - // inter-mix the gap between 1:N conversion between sparse tensors and - // fields. In this case, take the operands in the cast and replace the - // sparse tensor output with the flattened type array. - flattened.append(tuple.getOperands().begin(), tuple.getOperands().end()); - } else { - flattened.push_back(operand); - } - } +/// Flatten the given value ranges into a single vector of values. +static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) { + SmallVector<Value> result; + for (const auto &vals : values) + llvm::append_range(result, vals); + return result; +} + +/// Assert that the given value range contains a single value and return it. +static Value getSingleValue(ValueRange values) { + assert(values.size() == 1 && "expected single value"); + return values.front(); } /// Generates a load with proper `index` typing. 
@@ -567,12 +560,11 @@ class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector<Value> flattened; - flattenOperands(adaptor.getOperands(), flattened); // Create a return with the flattened value extracted from sparse tensors. - rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened); + rewriter.replaceOpWithNewOp<func::ReturnOp>( + op, flattenValues(adaptor.getOperands())); return success(); } }; @@ -583,7 +575,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> { // The default CallOp converter can not handle 1:N type conversion. using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(func::CallOp op, OpAdaptor adaptor, + matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); // In case of: @@ -596,10 +588,8 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> { return failure(); // (1) Generates new call with flattened return value. - SmallVector<Value> flattened; - flattenOperands(adaptor.getOperands(), flattened); - auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(), - finalRetTy, flattened); + auto newCall = rewriter.create<func::CallOp>( + loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands())); // (2) Gather sparse tensor returns. 
SmallVector<SmallVector<Value>> packedResultVals; // Tracks the offset of current return value (of the original call) @@ -643,7 +633,7 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(LvlOp op, OpAdaptor adaptor, + matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { std::optional<int64_t> lvl = op.getConstantLvlIndex(); RankedTensorType srcType = op.getSource().getType(); @@ -662,7 +652,7 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> { struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor, + matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); MLIRContext *ctx = op.getContext(); @@ -693,7 +683,7 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> { // Since we do in-place sorting, the destinate tensor will have the same set // of memrefs as the source tensor. - rewriter.replaceOp(op, adaptor.getInputCoo()); + rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()}); return success(); } }; @@ -703,7 +693,8 @@ class SparseSliceGetterOpConverter : public OpConversionPattern<Op> { public: using OpConversionPattern<Op>::OpConversionPattern; LogicalResult - matchAndRewrite(Op op, typename Op::Adaptor adaptor, + matchAndRewrite(Op op, + typename OpConversionPattern<Op>::OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Simply lowers to specifer.get <field> operation. 
auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(), @@ -721,14 +712,14 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, + matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only rewrite identically annotated source/dest. auto encDst = getSparseTensorEncoding(op.getType()); auto encSrc = getSparseTensorEncoding(op.getSource().getType()); if (!encDst || encDst != encSrc) return failure(); - rewriter.replaceOp(op, adaptor.getOperands()); + rewriter.replaceOpWithMultiple(op, {adaptor.getSource()}); return success(); } }; @@ -737,10 +728,10 @@ class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor, + matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Simply fold the operation. - rewriter.replaceOp(op, adaptor.getSource()); + rewriter.replaceOpWithMultiple(op, {adaptor.getSource()}); return success(); } }; @@ -756,7 +747,7 @@ class SparseTensorAllocConverter enableBufferInitialization(enableInit) {} LogicalResult - matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor, + matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { const auto resType = getSparseTensorType(op); if (!resType.hasEncoding()) @@ -791,7 +782,8 @@ class SparseTensorAllocConverter } // Level size equals to dimension size since lvl2dim map is an identity map. 
SmallVector<Value> lvlSizesValues; - createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(), + createDimSizes(rewriter, loc, resType, + flattenValues(adaptor.getDynamicSizes()), /*dimSizesValues=*/lvlSizesValues); // Construct allocation for each field. @@ -861,7 +853,7 @@ class SparseTensorDeallocConverter createDeallocs(createDeallocs) {} LogicalResult - matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor, + matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto enc = getSparseTensorEncoding(op.getTensor().getType()); if (!enc) @@ -892,7 +884,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(LoadOp op, OpAdaptor adaptor, + matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Prepare descriptor. auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), @@ -911,7 +903,7 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ExpandOp op, OpAdaptor adaptor, + matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!getSparseTensorEncoding(op.getTensor().getType())) return failure(); @@ -963,16 +955,16 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CompressOp op, OpAdaptor adaptor, + matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); SmallVector<Value> fields; auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields, op.getTensor().getType()); - Value values = adaptor.getValues(); - Value filled = adaptor.getFilled(); 
- Value added = adaptor.getAdded(); - Value count = adaptor.getCount(); + Value values = getSingleValue(adaptor.getValues()); + Value filled = getSingleValue(adaptor.getFilled()); + Value added = getSingleValue(adaptor.getAdded()); + Value count = getSingleValue(adaptor.getCount()); const SparseTensorType dstType(desc.getRankedTensorType()); Type eltType = dstType.getElementType(); @@ -1005,7 +997,8 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> { SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end()); SmallVector<Type> flatSpTensorTps = llvm::to_vector( llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); })); - params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end()); + SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords()); + params.append(flatLvlCoords.begin(), flatLvlCoords.end()); params.push_back(crd); params.push_back(value); SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps, @@ -1033,9 +1026,9 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor, + matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto stt = getSparseTensorType(adaptor.getDest()); + auto stt = getSparseTensorType(op.getDest()); if (!stt.hasEncoding()) return failure(); assert(stt.isIdentity() && "Run reinterpret-map before conversion."); @@ -1045,8 +1038,9 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> { getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType()); TypeRange flatSpTensorTps = desc.getFields().getTypes(); SmallVector<Value> params = llvm::to_vector(desc.getFields()); - params.append(adaptor.getIndices().begin(), adaptor.getIndices().end()); - params.push_back(adaptor.getScalar()); + 
SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices()); + params.append(flatIndices.begin(), flatIndices.end()); + params.push_back(getSingleValue(adaptor.getScalar())); SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps, params, /*genCall=*/true); SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc); @@ -1062,7 +1056,7 @@ class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> { using OpAdaptor = typename ToPositionsOp::Adaptor; using OpConversionPattern<ToPositionsOp>::OpConversionPattern; LogicalResult - matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor, + matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Replace the requested position access with corresponding field. // The view is restricted to the actual size to ensure clients @@ -1085,7 +1079,7 @@ class SparseToCoordinatesConverter using OpAdaptor = typename ToCoordinatesOp::Adaptor; using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern; LogicalResult - matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor, + matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Replace the requested coordinates access with corresponding field. // The view is restricted to the actual size to ensure clients @@ -1111,7 +1105,7 @@ class SparseToCoordinatesBufferConverter using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor; using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern; LogicalResult - matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor, + matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Replace the requested coordinates access with corresponding field. 
// The view is restricted to the actual size to ensure clients @@ -1133,7 +1127,7 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> { using OpAdaptor = typename ToValuesOp::Adaptor; using OpConversionPattern<ToValuesOp>::OpConversionPattern; LogicalResult - matchAndRewrite(ToValuesOp op, OpAdaptor adaptor, + matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Replace the requested values access with corresponding field. // The view is restricted to the actual size to ensure clients @@ -1153,7 +1147,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ConvertOp op, OpAdaptor adaptor, + matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType()); SparseTensorEncodingAttr encSrc = @@ -1173,7 +1167,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> { Type srcElemTp = op.getSource().getType().getElementType(); // Fold the trivial cases. 
if (retElemTp == srcElemTp && encDst == encSrc) { - rewriter.replaceOp(op, adaptor.getSource()); + rewriter.replaceOpWithMultiple(op, {adaptor.getSource()}); return success(); } // @@ -1239,7 +1233,7 @@ class SparseExtractSliceConverter public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor, + matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); MLIRContext *ctx = op.getContext(); @@ -1296,7 +1290,7 @@ class SparseNumberOfEntriesConverter public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor, + matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Query memSizes for the actually stored values. // FIXME: the nse value computed in this way might be wrong when there is @@ -1430,7 +1424,7 @@ struct SparseDisassembleOpConverter : OpConversionPattern(typeConverter, context) {} LogicalResult - matchAndRewrite(DisassembleOp op, OpAdaptor adaptor, + matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), op.getTensor().getType()); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h index 89858546e37e1b..869c7864d75354 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h @@ -228,11 +228,6 @@ class MutSparseTensorDescriptor } }; -/// Returns the "tuple" value of the adapted tensor. 
-inline UnrealizedConversionCastOp getTuple(Value tensor) { - return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp()); -} - /// Packs the given values as a "tuple" value. inline Value genTuple(OpBuilder &builder, Location loc, Type tp, ValueRange values) { @@ -246,16 +241,15 @@ inline Value genTuple(OpBuilder &builder, Location loc, } inline SparseTensorDescriptor -getDescriptorFromTensorTuple(Value tensor, RankedTensorType type) { - auto tuple = getTuple(tensor); - return SparseTensorDescriptor(SparseTensorType(type), tuple.getInputs()); +getDescriptorFromTensorTuple(ValueRange adaptorValues, RankedTensorType type) { + return SparseTensorDescriptor(SparseTensorType(type), adaptorValues); } inline MutSparseTensorDescriptor -getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields, +getMutDescriptorFromTensorTuple(ValueRange adaptorValues, + SmallVectorImpl<Value> &fields, RankedTensorType type) { - auto tuple = getTuple(tensor); - fields.assign(tuple.getInputs().begin(), tuple.getInputs().end()); + fields.assign(adaptorValues.begin(), adaptorValues.end()); return MutSparseTensorDescriptor(SparseTensorType(type), fields); } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 5b2cfd370900a8..6906dd6ee2ea7f 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -67,10 +67,6 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) { // ConversionValueMapping //===----------------------------------------------------------------------===// -/// A list of replacement SSA values. Optimized for the common case of a single -/// SSA value. -using ReplacementValues = SmallVector<Value, 1>; - namespace { /// This class wraps a IRMapping to provide recursive lookup /// functionality, i.e. we will traverse if the mapped value also has a mapping. 
@@ -780,7 +776,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { LogicalResult remapValues(StringRef valueDiagTag, std::optional<Location> inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVectorImpl<Value> &remapped); + SmallVector<SmallVector<Value>> &remapped); /// Return "true" if the given operation is ignored, and does not need to be /// converted. @@ -814,13 +810,27 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { // Materializations //===--------------------------------------------------------------------===// - /// Build an unresolved materialization operation given an output type and set - /// of input operands. - Value buildUnresolvedMaterialization(MaterializationKind kind, - OpBuilder::InsertPoint ip, Location loc, - ValueRange inputs, Type outputType, - Type originalType, - const TypeConverter *converter); + /// Build an unresolved materialization operation given a range of output + /// types and a list of input operands. Returns the inputs if their + /// types match the output types. + /// + /// If a cast op was built, it can optionally be returned with the `castOp` + /// output argument. + ValueRange buildUnresolvedMaterialization( + MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, + ValueRange inputs, TypeRange outputTypes, Type originalType, + const TypeConverter *converter, + UnrealizedConversionCastOp *castOp = nullptr); + Value buildUnresolvedMaterialization( + MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, + ValueRange inputs, Type outputType, Type originalType, + const TypeConverter *converter, + UnrealizedConversionCastOp *castOp = nullptr) { + return buildUnresolvedMaterialization(kind, ip, loc, inputs, + TypeRange(outputType), originalType, + converter, castOp) + .front(); + } /// Build an N:1 materialization for the given original value that was replaced with the given replacement values.
@@ -838,6 +848,16 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { ValueRange replacements, Value originalValue, const TypeConverter *converter); + /// Unpack an N:1 materialization and return the inputs of the + /// materialization. This function unpacks only those materializations that + /// were built with `insertNTo1Materialization`. + /// + /// This is a workaround around incomplete 1:N support in the dialect + /// conversion driver. It allows us to write 1:N conversion patterns while + /// 1:N support is still missing in the conversion value mapping. This + /// function will be deleted when full 1:N support has been added. + SmallVector<Value> unpackNTo1Materialization(Value value); + //===--------------------------------------------------------------------===// // Rewriter Notification Hooks //===--------------------------------------------------------------------===// @@ -847,7 +867,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { OpBuilder::InsertPoint previous) override; /// Notifies that an op is about to be replaced with the given values. - void notifyOpReplaced(Operation *op, ArrayRef<ReplacementValues> newValues); + void notifyOpReplaced(Operation *op, ArrayRef<ValueRange> newValues); /// Notifies that a block is about to be erased. void notifyBlockIsBeingErased(Block *block); @@ -940,6 +960,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *> unresolvedMaterializations; + /// A set of all N:1 materializations that were added to work around + /// incomplete 1:N support in the dialect conversion driver. + DenseSet<UnrealizedConversionCastOp> nTo1TempMaterializations; + /// The current type converter, or nullptr if no type converter is currently /// active. 
const TypeConverter *currentTypeConverter = nullptr; @@ -1076,6 +1100,7 @@ void UnresolvedMaterializationRewrite::rollback() { rewriterImpl.mapping.erase(input); } rewriterImpl.unresolvedMaterializations.erase(getOperation()); + rewriterImpl.nTo1TempMaterializations.erase(getOperation()); op->erase(); } @@ -1119,7 +1144,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { LogicalResult ConversionPatternRewriterImpl::remapValues( StringRef valueDiagTag, std::optional<Location> inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVectorImpl<Value> &remapped) { + SmallVector<SmallVector<Value>> &remapped) { remapped.reserve(llvm::size(values)); for (const auto &it : llvm::enumerate(values)) { @@ -1131,7 +1156,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( // The current pattern does not have a type converter. I.e., it does not // distinguish between legal and illegal types. For each operand, simply // pass through the most recently mapped value. - remapped.push_back(mapping.lookupOrDefault(operand)); + remapped.push_back({mapping.lookupOrDefault(operand)}); continue; } @@ -1145,15 +1170,32 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( return failure(); } + // If a type is converted to 0 types, there is nothing to do. + if (legalTypes.empty()) { + remapped.push_back({}); + continue; + } + if (legalTypes.size() != 1) { - // TODO: Parts of the dialect conversion infrastructure do not support - // 1->N type conversions yet. Therefore, if a type is converted to 0 or - // multiple types, the only thing that we can do for now is passing - // through the most recently mapped value. Fixing this requires - // improvements to the `ConversionValueMapping` (to be able to store 1:N - // mappings) and to the `ConversionPattern` adaptor handling (to be able - // to pass multiple remapped values for a single operand to the adaptor). 
- remapped.push_back(mapping.lookupOrDefault(operand)); + // TODO: This is a 1:N conversion. The conversion value mapping cannot + // support such conversions yet. It stores the result of an argument + // materialization (i.e., a conversion back into a single SSA value) + // instead. Unpack such "workaround" materializations and hand the + // original replacement values to the adaptor. + Value repl = mapping.lookupOrDefault(operand); + SmallVector<Value> unpacked = unpackNTo1Materialization(repl); + if (TypeRange(unpacked) == legalTypes) { + remapped.push_back(unpacked); + continue; + } + + // Insert a target materialization if the current pattern expects + // different legalized types. + ValueRange targetMat = buildUnresolvedMaterialization( + MaterializationKind::Target, computeInsertPoint(repl), operandLoc, + /*inputs=*/repl, /*outputType=*/legalTypes, + /*originalType=*/origType, currentTypeConverter); + remapped.push_back(targetMat); continue; } @@ -1165,7 +1207,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( if (newOperand.getType() != desiredType) { // If the looked up value's type does not have the desired type, it means // that the value was replaced with a value of different type and no - // source materialization was created yet. + // target materialization was created yet. Value castValue = buildUnresolvedMaterialization( MaterializationKind::Target, computeInsertPoint(newOperand), operandLoc, @@ -1174,7 +1216,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( mapping.map(newOperand, castValue); newOperand = castValue; } - remapped.push_back(newOperand); + remapped.push_back({newOperand}); } return success(); } @@ -1329,26 +1371,28 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /// Build an unresolved materialization operation given an output type and set /// of input operands.
-Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( +ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, - ValueRange inputs, Type outputType, Type originalType, - const TypeConverter *converter) { + ValueRange inputs, TypeRange outputTypes, Type originalType, + const TypeConverter *converter, UnrealizedConversionCastOp *castOp) { assert((!originalType || kind == MaterializationKind::Target) && "original type is valid only for target materializations"); // Avoid materializing an unnecessary cast. - if (inputs.size() == 1 && inputs.front().getType() == outputType) - return inputs.front(); + if (TypeRange(inputs) == outputTypes) + return inputs; // Create an unresolved materialization. We use a new OpBuilder to avoid // tracking the materialization like we do for other operations. - OpBuilder builder(outputType.getContext()); + OpBuilder builder(outputTypes.front().getContext()); builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); auto convertOp = - builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs); + builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs); + if (castOp) + *castOp = convertOp; appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind, originalType); - return convertOp.getResult(0); + return convertOp.getResults(); } void ConversionPatternRewriterImpl::insertNTo1Materialization( @@ -1356,10 +1400,13 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization( Value originalValue, const TypeConverter *converter) { // Insert argument materialization back to the original type. 
Type originalType = originalValue.getType(); - Value argMat = - buildUnresolvedMaterialization(MaterializationKind::Argument, ip, loc, - /*inputs=*/replacements, originalType, - /*originalType=*/Type(), converter); + UnrealizedConversionCastOp argCastOp; + Value argMat = buildUnresolvedMaterialization( + MaterializationKind::Argument, ip, loc, + /*inputs=*/replacements, originalType, + /*originalType=*/Type(), converter, &argCastOp); + if (argCastOp) + nTo1TempMaterializations.insert(argCastOp); mapping.map(originalValue, argMat); // Insert target materialization to the legalized type. @@ -1376,14 +1423,36 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization( legalOutputType = replacements[0].getType(); } if (legalOutputType && legalOutputType != originalType) { + UnrealizedConversionCastOp targetCastOp; Value targetMat = buildUnresolvedMaterialization( MaterializationKind::Target, computeInsertPoint(argMat), loc, /*inputs=*/argMat, /*outputType=*/legalOutputType, - /*originalType=*/originalType, converter); + /*originalType=*/originalType, converter, &targetCastOp); + if (targetCastOp) + nTo1TempMaterializations.insert(targetCastOp); mapping.map(argMat, targetMat); } } +SmallVector<Value> +ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) { + // Unpack unrealized_conversion_cast ops that were inserted as a N:1 + // workaround. + auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>(); + if (!castOp) + return {value}; + if (!nTo1TempMaterializations.contains(castOp)) + return {value}; + assert(castOp->getNumResults() == 1 && "expected single result"); + + SmallVector<Value> result; + for (Value v : castOp.getOperands()) { + // Keep unpacking if possible. 
+ llvm::append_range(result, unpackNTo1Materialization(v)); + } + return result; +} + //===----------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -1408,7 +1477,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( } void ConversionPatternRewriterImpl::notifyOpReplaced( - Operation *op, ArrayRef<ReplacementValues> newValues) { + Operation *op, ArrayRef<ValueRange> newValues) { assert(newValues.size() == op->getNumResults()); assert(!ignoredOps.contains(op) && "operation was already replaced"); @@ -1420,8 +1489,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( isUnresolvedMaterialization = true; // Create mappings for each of the new result values. - for (auto [n, result] : llvm::zip_equal(newValues, op->getResults())) { - ReplacementValues repl = n; + for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) { if (repl.empty()) { // This result was dropped and no replacement value was provided. if (isUnresolvedMaterialization) { @@ -1436,7 +1504,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( result.getLoc(), /*inputs=*/ValueRange(), /*outputType=*/result.getType(), /*originalType=*/Type(), currentTypeConverter); - repl.push_back(sourceMat); + mapping.map(result, sourceMat); + continue; } else { // Make sure that the user does not mess with unresolved materializations // that were inserted by the conversion driver. 
We keep track of these @@ -1538,10 +1607,9 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); - SmallVector<ReplacementValues> newVals(newValues.size()); - for (auto [index, val] : llvm::enumerate(newValues)) - if (val) - newVals[index].push_back(val); + SmallVector<ValueRange> newVals; + for (size_t i = 0; i < newValues.size(); ++i) + newVals.push_back(newValues.slice(i, 1)); impl->notifyOpReplaced(op, newVals); } @@ -1553,10 +1621,7 @@ void ConversionPatternRewriter::replaceOpWithMultiple( impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); - SmallVector<ReplacementValues> newVals(newValues.size(), {}); - for (auto [index, val] : llvm::enumerate(newValues)) - llvm::append_range(newVals[index], val); - impl->notifyOpReplaced(op, newVals); + impl->notifyOpReplaced(op, newValues); } void ConversionPatternRewriter::eraseOp(Operation *op) { @@ -1564,7 +1629,7 @@ impl->logger.startLine() << "** Erase : '" << op->getName() << "'(" << op << ")\n"; }); - SmallVector<ReplacementValues> nullRepls(op->getNumResults(), {}); + SmallVector<ValueRange> nullRepls(op->getNumResults(), {}); impl->notifyOpReplaced(op, nullRepls); } @@ -1615,11 +1680,12 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, } Value ConversionPatternRewriter::getRemappedValue(Value key) { - SmallVector<Value> remappedValues; + SmallVector<SmallVector<Value>> remappedValues; if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key, remappedValues))) return nullptr; - return remappedValues.front(); + assert(remappedValues.front().size() == 1 && "1:N conversion not supported"); + return remappedValues.front().front(); } LogicalResult @@ -1627,8 +1693,15 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys, SmallVectorImpl<Value>
&results) { if (keys.empty()) return success(); - return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, - results); + SmallVector<SmallVector<Value>> remapped; + if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, + remapped))) + return failure(); + for (const auto &values : remapped) { + assert(values.size() == 1 && "1:N conversion not supported"); + results.push_back(values.front()); + } + return success(); } void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, @@ -1722,6 +1795,19 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { // ConversionPattern //===----------------------------------------------------------------------===// +SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands( + ArrayRef<ValueRange> operands) const { + SmallVector<Value> oneToOneOperands; + oneToOneOperands.reserve(operands.size()); + for (ValueRange operand : operands) { + if (operand.size() != 1) + llvm::report_fatal_error("pattern '" + getDebugName() + + "' does not support 1:N conversion"); + oneToOneOperands.push_back(operand.front()); + } + return oneToOneOperands; +} + LogicalResult ConversionPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { @@ -1733,12 +1819,14 @@ ConversionPattern::matchAndRewrite(Operation *op, getTypeConverter()); // Remap the operands of the operation. 
- SmallVector<Value, 4> operands; + SmallVector<SmallVector<Value>> remapped; if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, - op->getOperands(), operands))) { + op->getOperands(), remapped))) { return failure(); } - return matchAndRewrite(op, operands, dialectRewriter); + SmallVector<ValueRange> remappedAsRange = llvm::map_to_vector( + remapped, [](const auto &v) -> ValueRange { return v; }); + return matchAndRewrite(op, remappedAsRange, dialectRewriter); } //===----------------------------------------------------------------------===// @@ -1965,19 +2053,19 @@ OperationLegalizer::legalizeWithFold(Operation *op, }); // Try to fold the operation. - SmallVector<Value, 2> replacementValues; + SmallVector<Value, 2> replacementValues; rewriter.setInsertionPoint(op); - if (failed(rewriter.tryFold(op, replacementValues))) { + if (failed(rewriter.tryFold(op, replacementValues))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold")); return failure(); } // An empty list of replacement values indicates that the fold was in-place. // As the operation changed, a new legalization needs to be attempted. - if (replacementValues.empty()) + if (replacementValues.empty()) return legalize(op, rewriter); // Insert a replacement for 'op' with the folded replacement values. - rewriter.replaceOp(op, replacementValues); + rewriter.replaceOp(op, replacementValues); // Recursively legalize any new constant operations. for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size(); @@ -2482,45 +2570,52 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, assert(!op.use_empty() && "expected that dead materializations have already been DCE'd"); Operation::operand_range inputOperands = op.getOperands(); - Type outputType = op.getResultTypes()[0]; // Try to materialize the conversion.
if (const TypeConverter *converter = rewrite->getConverter()) { rewriter.setInsertionPoint(op); - Value newMaterialization; + SmallVector<Value> newMaterialization; switch (rewrite->getMaterializationKind()) { - case MaterializationKind::Argument: + case MaterializationKind::Argument: { // Try to materialize an argument conversion. - newMaterialization = converter->materializeArgumentConversion( - rewriter, op->getLoc(), outputType, inputOperands); - if (newMaterialization) + assert(op->getNumResults() == 1 && "expected single result"); + Value argMat = converter->materializeArgumentConversion( + rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands); + if (argMat) { + newMaterialization.push_back(argMat); break; + } + } // If an argument materialization failed, fallback to trying a target // materialization. [[fallthrough]]; case MaterializationKind::Target: newMaterialization = converter->materializeTargetConversion( - rewriter, op->getLoc(), outputType, inputOperands, + rewriter, op->getLoc(), op.getResultTypes(), inputOperands, rewrite->getOriginalType()); break; case MaterializationKind::Source: - newMaterialization = converter->materializeSourceConversion( - rewriter, op->getLoc(), outputType, inputOperands); + assert(op->getNumResults() == 1 && "expected single result"); + Value sourceMat = converter->materializeSourceConversion( + rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands); + if (sourceMat) + newMaterialization.push_back(sourceMat); break; } - if (newMaterialization) { - assert(newMaterialization.getType() == outputType && + if (!newMaterialization.empty()) { + assert(TypeRange(newMaterialization) == op.getResultTypes() && "materialization callback produced value of incorrect type"); rewriter.replaceOp(op, newMaterialization); return success(); } } - InFlightDiagnostic diag = - op->emitError() << "failed to legalize unresolved materialization " - "from (" - << inputOperands.getTypes() << ") to (" << outputType - << ") 
that remained live after conversion"; + InFlightDiagnostic diag = op->emitError() + << "failed to legalize unresolved materialization " + "from (" + << inputOperands.getTypes() << ") to (" + << op.getResultTypes() + << ") that remained live after conversion"; diag.attachNote(op->getUsers().begin()->getLoc()) << "see existing live user here: " << *op->getUsers().begin(); return failure(); diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir index b8fad63eb4de67..4e641317ac2f3d 100644 --- a/mlir/test/Transforms/decompose-call-graph-types.mlir +++ b/mlir/test/Transforms/decompose-call-graph-types.mlir @@ -9,10 +9,7 @@ // CHECK-LABEL: func @identity( // CHECK-SAME: %[[ARG0:.*]]: i1, // CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { -// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32> -// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1 -// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32 -// CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK: return %[[ARG0]], %[[ARG1]] : i1, i32 // CHECK-12N-LABEL: func @identity( // CHECK-12N-SAME: %[[ARG0:.*]]: i1, // CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { @@ -56,18 +53,7 @@ func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tupl // CHECK-LABEL: func @mixed_recursive_decomposition( // CHECK-SAME: %[[ARG0:.*]]: i1, // CHECK-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { -// CHECK: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<> -// CHECK: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]]) : (i1) -> tuple<i1> -// CHECK: %[[V2:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2> -// CHECK: %[[V3:.*]] = "test.make_tuple"(%[[V2]]) : (tuple<i2>) -> tuple<tuple<i2>> -// CHECK: %[[V4:.*]] = "test.make_tuple"(%[[V0]], %[[V1]], %[[V3]]) : (tuple<>, tuple<i1>, 
tuple<tuple<i2>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>> -// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<> -// CHECK: %[[V6:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 1 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<i1> -// CHECK: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<i1>) -> i1 -// CHECK: %[[V8:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 2 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<i2>> -// CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2> -// CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) <{index = 0 : i32}> : (tuple<i2>) -> i2 -// CHECK: return %[[V7]], %[[V10]] : i1, i2 +// CHECK: return %[[ARG0]], %[[ARG1]] : i1, i2 // CHECK-12N-LABEL: func @mixed_recursive_decomposition( // CHECK-12N-SAME: %[[ARG0:.*]]: i1, // CHECK-12N-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { @@ -87,14 +73,8 @@ func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32> // CHECK-LABEL: func @caller( // CHECK-SAME: %[[ARG0:.*]]: i1, // CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { -// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32> -// CHECK: %[[CALL_ARG0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1 -// CHECK: %[[CALL_ARG1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32 -// CHECK: %[[DECOMPOSED:.*]]:2 = call @callee(%[[CALL_ARG0]], %[[CALL_ARG1]]) : (i1, i32) -> (i1, i32) -// CHECK: %[[CALL_RESULT_RECOMPOSED:.*]] = "test.make_tuple"(%[[DECOMPOSED]]#0, %[[DECOMPOSED]]#1) : (i1, i32) -> tuple<i1, i32> -// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1 -// CHECK: %[[RET1:.*]] = 
"test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32 -// CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32) +// CHECK: return %[[V0]]#0, %[[V0]]#1 : i1, i32 // CHECK-12N-LABEL: func @caller( // CHECK-12N-SAME: %[[ARG0:.*]]: i1, // CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { @@ -190,14 +170,8 @@ func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tup // CHECK-SAME: %[[I4:.*]]: i4, // CHECK-SAME: %[[I5:.*]]: i5, // CHECK-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) { -// CHECK: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[I4]], %[[I5]]) : (i4, i5) -> tuple<i4, i5> -// CHECK: %[[ARG_TUPLE_0:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4 -// CHECK: %[[ARG_TUPLE_1:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5 -// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[ARG_TUPLE_0]], %[[ARG_TUPLE_1]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) -// CHECK: %[[RET_TUPLE:.*]] = "test.make_tuple"(%[[CALL]]#3, %[[CALL]]#4) : (i4, i5) -> tuple<i4, i5> -// CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4 -// CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5 -// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 +// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) +// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 // CHECK-12N-LABEL: func @caller( // CHECK-12N-SAME: %[[I1:.*]]: i1, // CHECK-12N-SAME: %[[I2:.*]]: i2, 
_______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits