Author: Matthias Springer Date: 2024-11-21T10:40:33+09:00 New Revision: fe0ac007ca9e253e79d2dc0e95ce166efd585a5b
URL: https://github.com/llvm/llvm-project/commit/fe0ac007ca9e253e79d2dc0e95ce166efd585a5b DIFF: https://github.com/llvm/llvm-project/commit/fe0ac007ca9e253e79d2dc0e95ce166efd585a5b.diff LOG: Revert "[mlir][Transforms][NFC] Dialect conversion: Remove "finalize" phase (…" This reverts commit aa65473c9ddcf3cbb80e63c38af842d05346374b. Added: Modified: mlir/lib/Transforms/Utils/DialectConversion.cpp Removed: ################################################################################ diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 03d483f73f255e..42fe5b925654a1 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -75,10 +75,6 @@ namespace { /// This class wraps a IRMapping to provide recursive lookup /// functionality, i.e. we will traverse if the mapped value also has a mapping. struct ConversionValueMapping { - /// Return "true" if an SSA value is mapped to the given value. May return - /// false positives. - bool isMappedTo(Value value) const { return mappedTo.contains(value); } - /// Lookup the most recently mapped value with the desired type in the /// mapping. /// @@ -103,18 +99,22 @@ struct ConversionValueMapping { assert(it != oldVal && "inserting cyclic mapping"); }); mapping.map(oldVal, newVal); - mappedTo.insert(newVal); } /// Drop the last mapping for the given value. void erase(Value value) { mapping.erase(value); } + /// Returns the inverse raw value mapping (without recursive query support). + DenseMap<Value, SmallVector<Value>> getInverse() const { + DenseMap<Value, SmallVector<Value>> inverse; + for (auto &it : mapping.getValueMap()) + inverse[it.second].push_back(it.first); + return inverse; + } + private: /// Current value mappings. IRMapping mapping; - - /// All SSA values that are mapped to. May contain false positives. - DenseSet<Value> mappedTo; }; } // namespace @@ -434,9 +434,10 @@ class MoveBlockRewrite : public BlockRewrite { class BlockTypeConversionRewrite : public BlockRewrite { public: BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Block *block, Block *origBlock) + Block *block, Block *origBlock, + const TypeConverter *converter) : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block), - origBlock(origBlock) {} + origBlock(origBlock), converter(converter) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::BlockTypeConversion; @@ -444,6 +445,8 @@ class BlockTypeConversionRewrite : public BlockRewrite { Block *getOrigBlock() const { return origBlock; } + const TypeConverter *getConverter() const { return converter; } + void commit(RewriterBase &rewriter) override; void rollback() override; @@ -451,6 +454,9 @@ class BlockTypeConversionRewrite : public BlockRewrite { private: /// The original block that was requested to have its signature converted. Block *origBlock; + + /// The type converter used to convert the arguments. + const TypeConverter *converter; }; /// Replacing a block argument. This rewrite is not immediately reflected in the @@ -459,10 +465,8 @@ class BlockTypeConversionRewrite : public BlockRewrite { class ReplaceBlockArgRewrite : public BlockRewrite { public: ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Block *block, BlockArgument arg, - const TypeConverter *converter) - : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg), - converter(converter) {} + Block *block, BlockArgument arg) + : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::ReplaceBlockArg; @@ -474,9 +478,6 @@ class ReplaceBlockArgRewrite : public BlockRewrite { private: BlockArgument arg; - - /// The current type converter when the block argument was replaced. - const TypeConverter *converter; }; /// An operation rewrite. @@ -626,6 +627,8 @@ class ReplaceOperationRewrite : public OperationRewrite { void cleanup(RewriterBase &rewriter) override; + const TypeConverter *getConverter() const { return converter; } + private: /// An optional type converter that can be used to materialize conversions /// between the new and old values if necessary. @@ -822,14 +825,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { ValueRange replacements, Value originalValue, const TypeConverter *converter); - /// Find a replacement value for the given SSA value in the conversion value - /// mapping. The replacement value must have the same type as the given SSA - /// value. If there is no replacement value with the correct type, find the - /// latest replacement value (regardless of the type) and build a source - /// materialization. - Value findOrBuildReplacementValue(Value value, - const TypeConverter *converter); - //===--------------------------------------------------------------------===// // Rewriter Notification Hooks //===--------------------------------------------------------------------===// @@ -975,7 +970,7 @@ void BlockTypeConversionRewrite::rollback() { } void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { - Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); + Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType()); if (!repl) return; @@ -1004,7 +999,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { // Compute replacement values. SmallVector<Value> replacements = llvm::map_to_vector(op->getResults(), [&](OpResult result) { - return rewriterImpl.findOrBuildReplacementValue(result, converter); + return rewriterImpl.mapping.lookupOrNull(result, result.getType()); }); // Notify the listener that the operation is about to be replaced. @@ -1074,10 +1069,8 @@ void UnresolvedMaterializationRewrite::rollback() { void ConversionPatternRewriterImpl::applyRewrites() { // Commit all rewrites. IRRewriter rewriter(context, config.listener); - // Note: New rewrites may be added during the "commit" phase and the - // `rewrites` vector may reallocate. - for (size_t i = 0; i < rewrites.size(); ++i) - rewrites[i]->commit(rewriter); + for (auto &rewrite : rewrites) + rewrite->commit(rewriter); // Clean up all rewrites. for (auto &rewrite : rewrites) @@ -1282,7 +1275,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /*inputs=*/ValueRange(), /*outputType=*/origArgType, /*originalType=*/Type(), converter); mapping.map(origArg, repl); - appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); + appendRewrite<ReplaceBlockArgRewrite>(block, origArg); continue; } @@ -1292,7 +1285,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( "invalid to provide a replacement value when the argument isn't " "dropped"); mapping.map(origArg, repl); - appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); + appendRewrite<ReplaceBlockArgRewrite>(block, origArg); continue; } @@ -1305,10 +1298,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( insertNTo1Materialization( OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), /*replacements=*/replArgs, /*outputValue=*/origArg, converter); - appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); + appendRewrite<ReplaceBlockArgRewrite>(block, origArg); } - appendRewrite<BlockTypeConversionRewrite>(newBlock, block); + appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter); // Erase the old block. (It is just unlinked for now and will be erased during // cleanup.) @@ -1378,41 +1371,6 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization( } } -Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( - Value value, const TypeConverter *converter) { - // Find a replacement value with the same type. - Value repl = mapping.lookupOrNull(value, value.getType()); - if (repl) - return repl; - - // Check if the value is dead. No replacement value is needed in that case. - // This is an approximate check that may have false negatives but does not - // require computing and traversing an inverse mapping. (We may end up - // building source materializations that are never used and that fold away.) - if (llvm::all_of(value.getUsers(), - [&](Operation *op) { return replacedOps.contains(op); }) && - !mapping.isMappedTo(value)) - return Value(); - - // No replacement value was found. Get the latest replacement value - // (regardless of the type) and build a source materialization to the - // original type. - repl = mapping.lookupOrNull(value); - if (!repl) { - // No replacement value is registered in the mapping. This means that the - // value is dropped and no longer needed. (If the value were still needed, - // a source materialization producing a replacement value "out of thin air" - // would have already been created during `replaceOp` or - // `applySignatureConversion`.) - return Value(); - } - Value castValue = buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(), - /*inputs=*/repl, /*outputType=*/value.getType(), - /*originalType=*/Type(), converter); - return castValue; -} - //===----------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -1639,8 +1597,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, << "'(in region of '" << parentOp->getName() << "'(" << from.getOwner()->getParentOp() << ")\n"; }); - impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, - impl->currentTypeConverter); + impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from); impl->mapping.map(impl->mapping.lookupOrDefault(from), to); } @@ -2460,6 +2417,10 @@ struct OperationConverter { /// Converts an operation with the given rewriter. LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); + /// This method is called after the conversion process to legalize any + /// remaining artifacts and complete the conversion. + void finalize(ConversionPatternRewriter &rewriter); + /// Dialect conversion configuration. ConversionConfig config; @@ -2580,6 +2541,11 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { if (failed(convert(rewriter, op))) return rewriterImpl.undoRewrites(), failure(); + // Now that all of the operations have been converted, finalize the conversion + // process to ensure any lingering conversion artifacts are cleaned up and + // legalized. + finalize(rewriter); + // After a successful conversion, apply rewrites. rewriterImpl.applyRewrites(); @@ -2613,6 +2579,80 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { return success(); } +/// Finds a user of the given value, or of any other value that the given value +/// replaced, that was not replaced in the conversion process. +static Operation *findLiveUserOfReplaced( + Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, + const DenseMap<Value, SmallVector<Value>> &inverseMapping) { + SmallVector<Value> worklist = {initialValue}; + while (!worklist.empty()) { + Value value = worklist.pop_back_val(); + + // Walk the users of this value to see if there are any live users that + // weren't replaced during conversion. + auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) { + return rewriterImpl.isOpIgnored(user); + }); + if (liveUserIt != value.user_end()) + return *liveUserIt; + auto mapIt = inverseMapping.find(value); + if (mapIt != inverseMapping.end()) + worklist.append(mapIt->second); + } + return nullptr; +} + +/// Helper function that returns the replaced values and the type converter if +/// the given rewrite object is an "operation replacement" or a "block type +/// conversion" (which corresponds to a "block replacement"). Otherwise, return +/// an empty ValueRange and a null type converter pointer. +static std::pair<ValueRange, const TypeConverter *> +getReplacedValues(IRRewrite *rewrite) { + if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite)) + return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()}; + if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite)) + return {blockRewrite->getOrigBlock()->getArguments(), + blockRewrite->getConverter()}; + return {}; +} + +void OperationConverter::finalize(ConversionPatternRewriter &rewriter) { + ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); + DenseMap<Value, SmallVector<Value>> inverseMapping = + rewriterImpl.mapping.getInverse(); + + // Process requested value replacements. + for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) { + ValueRange replacedValues; + const TypeConverter *converter; + std::tie(replacedValues, converter) = + getReplacedValues(rewriterImpl.rewrites[i].get()); + for (Value originalValue : replacedValues) { + // If the type of this value changed and the value is still live, we need + // to materialize a conversion. + if (rewriterImpl.mapping.lookupOrNull(originalValue, + originalValue.getType())) + continue; + Operation *liveUser = + findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping); + if (!liveUser) + continue; + + // Legalize this value replacement. + Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue); + assert(newValue && "replacement value not found"); + Value castValue = rewriterImpl.buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(newValue), + originalValue.getLoc(), + /*inputs=*/newValue, /*outputType=*/originalValue.getType(), + /*originalType=*/Type(), converter); + rewriterImpl.mapping.map(originalValue, castValue); + inverseMapping[castValue].push_back(originalValue); + llvm::erase(inverseMapping[newValue], originalValue); + } + } +} + //===----------------------------------------------------------------------===// // Reconcile Unrealized Casts //===----------------------------------------------------------------------===// _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits