llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

The dialect conversion maintains a set of unresolved materializations (`UnrealizedConversionCastOp`). Turn that set into a `DenseMap` that maps from ops to `UnresolvedMaterializationRewrite *`. This improves efficiency a bit, because an iteration over `ConversionPatternRewriterImpl::rewrites` can be avoided.

Also delete some dead code.

---
Full diff: https://github.com/llvm/llvm-project/pull/108359.diff

1 Files Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+20-40)

``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b58a95c3baf70a..ed15b571f01883 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -688,9 +688,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   UnresolvedMaterializationRewrite(
       ConversionPatternRewriterImpl &rewriterImpl,
       UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
-      MaterializationKind kind = MaterializationKind::Target)
-      : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-        converterAndKind(converter, kind) {}
+      MaterializationKind kind = MaterializationKind::Target);
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -730,26 +728,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
   });
 }
 
-/// Find the single rewrite object of the specified type and block among the
-/// given rewrites. In debug mode, asserts that there is mo more than one such
-/// object. Return "nullptr" if no object was found.
-template <typename RewriteTy, typename R>
-static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
-  RewriteTy *result = nullptr;
-  for (auto &rewrite : rewrites) {
-    auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
-    if (rewriteTy && rewriteTy->getBlock() == block) {
-#ifndef NDEBUG
-      assert(!result && "expected single matching rewrite");
-      result = rewriteTy;
-#else
-      return rewriteTy;
-#endif // NDEBUG
-    }
-  }
-  return result;
-}
-
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriterImpl
 //===----------------------------------------------------------------------===//
@@ -892,10 +870,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
     bool wasErased(void *ptr) const { return erased.contains(ptr); }
 
-    bool wasErased(OperationRewrite *rewrite) const {
-      return wasErased(rewrite->getOperation());
-    }
-
     void notifyOperationErased(Operation *op) override { erased.insert(op); }
 
     void notifyBlockErased(Block *block) override { erased.insert(block); }
@@ -935,8 +909,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// to modify/access them is invalid rewriter API usage.
   SetVector<Operation *> replacedOps;
 
-  /// A set of all unresolved materializations.
-  DenseSet<Operation *> unresolvedMaterializations;
+  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
+  /// to the corresponding rewrite objects.
+  DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+      unresolvedMaterializations;
 
   /// The current type converter, or nullptr if no type converter is currently
   /// active.
@@ -1058,6 +1034,14 @@ void CreateOperationRewrite::rollback() {
   op->erase();
 }
 
+UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
+    ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
+    const TypeConverter *converter, MaterializationKind kind)
+    : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+      converterAndKind(converter, kind) {
+  rewriterImpl.unresolvedMaterializations[op] = this;
+}
+
 void UnresolvedMaterializationRewrite::rollback() {
   if (getMaterializationKind() == MaterializationKind::Target) {
     for (Value input : op->getOperands())
@@ -1345,7 +1329,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
-  unresolvedMaterializations.insert(convertOp);
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
   return convertOp.getResult(0);
 }
@@ -2499,15 +2482,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
 
   // Gather all unresolved materializations.
   SmallVector<UnrealizedConversionCastOp> allCastOps;
-  DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
-  for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites) {
-    auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
-    if (!mat)
-      continue;
-    if (rewriterImpl.eraseRewriter.wasErased(mat))
+  const DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+      &materializations = rewriterImpl.unresolvedMaterializations;
+  for (auto it : materializations) {
+    if (rewriterImpl.eraseRewriter.wasErased(it.first))
       continue;
-    allCastOps.push_back(mat->getOperation());
-    rewriteMap[mat->getOperation()] = mat;
+    allCastOps.push_back(cast<UnrealizedConversionCastOp>(it.first));
   }
 
   // Reconcile all UnrealizedConversionCastOps that were inserted by the
@@ -2520,8 +2500,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   if (config.buildMaterializations) {
     IRRewriter rewriter(rewriterImpl.context, config.listener);
     for (UnrealizedConversionCastOp castOp : remainingCastOps) {
-      auto it = rewriteMap.find(castOp.getOperation());
-      assert(it != rewriteMap.end() && "inconsistent state");
+      auto it = materializations.find(castOp.getOperation());
+      assert(it != materializations.end() && "inconsistent state");
       if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
         return failure();
     }

``````````

</details>

https://github.com/llvm/llvm-project/pull/108359
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits