https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/108381
PR #106760 aligned the handling of dropped block arguments and dropped op results. The two helper functions that insert source materializations for uses of replaced block arguments / op results that survived the conversion are now almost identical (`legalizeConvertedArgumentTypes` and `legalizeConvertedOpResultTypes`). This PR merges the two functions and moves the implementation directly into `finalize`. This PR simplifies the code base and improves the efficiency a bit: previously, `finalize` iterated over `ConversionPatternRewriterImpl::rewrites` twice. Now, only one iteration is needed. >From 1f215ac7861a76f653c9911a31bf484a5fd6dac4 Mon Sep 17 00:00:00 2001 From: Matthias Springer <msprin...@nvidia.com> Date: Thu, 12 Sep 2024 14:49:23 +0200 Subject: [PATCH] [mlir][Transforms] Dialect conversion: Unify materialization of value replacements PR #106760 aligned the handling of dropped block arguments and dropped op results. The two helper functions that insert source materializations for uses of replaced block arguments / op results that survived the conversion are now almost identical (`legalizeConvertedArgumentTypes` and `legalizeConvertedOpResultTypes`). This PR merges the two functions and moves the implementation directly into `finalize`. This PR simplifies the code base and improves the efficiency a bit: previously, `finalize` iterated over `ConversionPatternRewriterImpl::rewrites` twice. Now, only one iteration is needed. --- .../Transforms/Utils/DialectConversion.cpp | 134 ++++++------------ .../VectorToSPIRV/vector-to-spirv.mlir | 4 +- 2 files changed, 44 insertions(+), 94 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index ed15b571f01883..0556b4ab833c30 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2336,17 +2336,6 @@ struct OperationConverter { /// remaining artifacts and complete the conversion.
LogicalResult finalize(ConversionPatternRewriter &rewriter); - /// Legalize the types of converted block arguments. - LogicalResult - legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl); - - /// Legalize the types of converted op results. - LogicalResult legalizeConvertedOpResultTypes( - ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl, - DenseMap<Value, SmallVector<Value>> &inverseMapping); - /// Dialect conversion configuration. ConversionConfig config; @@ -2510,19 +2499,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { return success(); } -LogicalResult -OperationConverter::finalize(ConversionPatternRewriter &rewriter) { - ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); - if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl))) - return failure(); - DenseMap<Value, SmallVector<Value>> inverseMapping = - rewriterImpl.mapping.getInverse(); - if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl, - inverseMapping))) - return failure(); - return success(); -} - /// Finds a user of the given value, or of any other value that the given value /// replaced, that was not replaced in the conversion process. static Operation *findLiveUserOfReplaced( @@ -2546,87 +2522,61 @@ static Operation *findLiveUserOfReplaced( return nullptr; } -LogicalResult OperationConverter::legalizeConvertedOpResultTypes( - ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl, - DenseMap<Value, SmallVector<Value>> &inverseMapping) { - // Process requested operation replacements. 
- for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) { - auto *opReplacement = - dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get()); - if (!opReplacement) - continue; - Operation *op = opReplacement->getOperation(); - for (OpResult result : op->getResults()) { - // If the type of this op result changed and the result is still live, - // we need to materialize a conversion. - if (rewriterImpl.mapping.lookupOrNull(result, result.getType())) +/// Helper function that returns the replaced values and the type converter if +/// the given rewrite object is an "operation replacement" or a "block type +/// conversion" (which corresponds to a "block replacement"). Otherwise, return +/// an empty ValueRange and a null type converter pointer. +static std::pair<ValueRange, const TypeConverter *> +getReplacedValues(IRRewrite *rewrite) { + if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite)) + return std::make_pair(opRewrite->getOperation()->getResults(), + opRewrite->getConverter()); + if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite)) + return std::make_pair(blockRewrite->getOrigBlock()->getArguments(), + blockRewrite->getConverter()); + return std::make_pair(ValueRange(), nullptr); +} + +LogicalResult +OperationConverter::finalize(ConversionPatternRewriter &rewriter) { + ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); + DenseMap<Value, SmallVector<Value>> inverseMapping = + rewriterImpl.mapping.getInverse(); + + // Process requested value replacements. + for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) { + ValueRange replacedValues; + const TypeConverter *converter; + std::tie(replacedValues, converter) = + getReplacedValues(rewriterImpl.rewrites[i].get()); + for (Value originalValue : replacedValues) { + // If the type of this value changed and the value is still live, we need + // to materialize a conversion. 
+ if (rewriterImpl.mapping.lookupOrNull(originalValue, + originalValue.getType())) continue; Operation *liveUser = - findLiveUserOfReplaced(result, rewriterImpl, inverseMapping); + findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping); if (!liveUser) continue; - // Legalize this result. - Value newValue = rewriterImpl.mapping.lookupOrNull(result); + // Legalize this value replacement. + Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue); assert(newValue && "replacement value not found"); Value castValue = rewriterImpl.buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(result), op->getLoc(), - /*inputs=*/newValue, /*outputType=*/result.getType(), - opReplacement->getConverter()); - rewriterImpl.mapping.map(result, castValue); - inverseMapping[castValue].push_back(result); - llvm::erase(inverseMapping[newValue], result); + MaterializationKind::Source, computeInsertPoint(newValue), + originalValue.getLoc(), + /*inputs=*/newValue, /*outputType=*/originalValue.getType(), + converter); + rewriterImpl.mapping.map(originalValue, castValue); + inverseMapping[castValue].push_back(originalValue); + llvm::erase(inverseMapping[newValue], originalValue); } } return success(); } -LogicalResult OperationConverter::legalizeConvertedArgumentTypes( - ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl) { - // Functor used to check if all users of a value will be dead after - // conversion. - // TODO: This should probably query the inverse mapping, same as in - // `legalizeConvertedOpResultTypes`. - auto findLiveUser = [&](Value val) { - auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) { - return rewriterImpl.isOpIgnored(user); - }); - return liveUserIt == val.user_end() ? nullptr : *liveUserIt; - }; - // Note: `rewrites` may be reallocated as the loop is running. 
- for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size()); - ++i) { - auto &rewrite = rewriterImpl.rewrites[i]; - if (auto *blockTypeConversionRewrite = - dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) { - // Process the remapping for each of the original arguments. - for (Value origArg : - blockTypeConversionRewrite->getOrigBlock()->getArguments()) { - // If the type of this argument changed and the argument is still live, - // we need to materialize a conversion. - if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType())) - continue; - Operation *liveUser = findLiveUser(origArg); - if (!liveUser) - continue; - - Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg); - assert(replacementValue && "replacement value not found"); - Value repl = rewriterImpl.buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(replacementValue), - origArg.getLoc(), /*inputs=*/replacementValue, - /*outputType=*/origArg.getType(), - blockTypeConversionRewrite->getConverter()); - rewriterImpl.mapping.map(origArg, repl); - } - } - } - return success(); -} - //===----------------------------------------------------------------------===// // Reconcile Unrealized Casts //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index d8570bdaf4247f..25ec5d0159bd5d 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -558,8 +558,8 @@ func.func @deinterleave(%a: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) { // CHECK-LABEL: func @deinterleave_scalar // CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>) -// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32> -// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32> +// CHECK-DAG: 
%[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32> +// CHECK-DAG: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32> // CHECK-DAG: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32> // CHECK-DAG: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32> // CHECK: return %[[CAST0]], %[[CAST1]] _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits