https://github.com/matthias-springer created 
https://github.com/llvm/llvm-project/pull/145319

This commit adds extra state to `ConversionPatternRewriterImpl` to store all 
modified / newly-created operations and moved / newly-created blocks in 
separate lists on a per-pattern basis.

This is in preparation for the One-Shot Dialect Conversion refactoring: the new
driver will no longer maintain a list of all IR rewrites, so information about
newly-created operations (which is needed to trigger recursive legalization)
must be retained in a different data structure.

This commit is also expected to improve the performance of the existing driver. 
The previous implementation iterated over all new IR modifications and then 
filtered them by type. It also required an additional pointer indirection 
(through `std::unique_ptr<IRRewrite>`) to retrieve the operation/block pointer.

Depends on #145308.


From 0fb4c01f02f042dda35a9ad6874d38c22e06c9bc Mon Sep 17 00:00:00 2001
From: Matthias Springer <m...@m-sp.org>
Date: Mon, 23 Jun 2025 12:36:02 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Store per-pattern IR modifications in
 separate state

---
 .../Transforms/Utils/DialectConversion.cpp    | 139 ++++++++++--------
 1 file changed, 75 insertions(+), 64 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp 
b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7cfe7250d02c3..955a106c21941 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1080,6 +1080,16 @@ struct ConversionPatternRewriterImpl : public 
RewriterBase::Listener {
   /// to modify/access them is invalid rewriter API usage.
   SetVector<Operation *> replacedOps;
 
+  /// A set of operations that were created by the current pattern.
+  SetVector<Operation *> patternNewOps;
+
+  /// A set of operations that were modified by the current pattern.
+  SetVector<Operation *> patternModifiedOps;
+
+  /// A set of blocks that were inserted (newly-created blocks or moved blocks)
+  /// by the current pattern.
+  SetVector<Block *> patternInsertedBlocks;
+
   /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
   /// to the corresponding rewrite objects.
   DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
@@ -1571,6 +1581,7 @@ void 
ConversionPatternRewriterImpl::notifyOperationInserted(
   if (!previous.isSet()) {
     // This is a newly created op.
     appendRewrite<CreateOperationRewrite>(op);
+    patternNewOps.insert(op);
     return;
   }
   Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
@@ -1655,6 +1666,8 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
         }
       });
 
+  patternInsertedBlocks.insert(block);
+
   if (!previous) {
     // This is a newly created block.
     appendRewrite<CreateBlockRewrite>(block);
@@ -1852,6 +1865,8 @@ void 
ConversionPatternRewriter::finalizeOpModification(Operation *op) {
   assert(!impl->wasOpReplaced(op) &&
          "attempting to modify a replaced/erased op");
   PatternRewriter::finalizeOpModification(op);
+  impl->patternModifiedOps.insert(op);
+
   // There is nothing to do here, we only need to track the operation at the
   // start of the update.
 #ifndef NDEBUG
@@ -1964,21 +1979,25 @@ class OperationLegalizer {
   /// Legalize the resultant IR after successfully applying the given pattern.
   LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
                                       ConversionPatternRewriter &rewriter,
-                                      RewriterState &curState);
+                                      const SetVector<Operation *> &newOps,
+                                      const SetVector<Operation *> 
&modifiedOps,
+                                      const SetVector<Block *> 
&insertedBlocks);
 
   /// Legalizes the actions registered during the execution of a pattern.
   LogicalResult
   legalizePatternBlockRewrites(Operation *op,
                                ConversionPatternRewriter &rewriter,
                                ConversionPatternRewriterImpl &impl,
-                               RewriterState &state, RewriterState &newState);
-  LogicalResult legalizePatternCreatedOperations(
-      ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
-      RewriterState &state, RewriterState &newState);
-  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
-                                           ConversionPatternRewriterImpl &impl,
-                                           RewriterState &state,
-                                           RewriterState &newState);
+                               const SetVector<Block *> &insertedBlocks,
+                               const SetVector<Operation *> &newOps);
+  LogicalResult
+  legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
+                                   ConversionPatternRewriterImpl &impl,
+                                   const SetVector<Operation *> &newOps);
+  LogicalResult
+  legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
+                             ConversionPatternRewriterImpl &impl,
+                             const SetVector<Operation *> &modifiedOps);
 
   
//===--------------------------------------------------------------------===//
   // Cost Model
@@ -2131,6 +2150,15 @@ OperationLegalizer::legalize(Operation *op,
   return failure();
 }
 
+/// Helper function that moves and returns the given object. Also resets the
+/// original object, so that it is in a valid, empty state again.
+template <typename T>
+static T moveAndReset(T &obj) {
+  T result = std::move(obj);
+  obj = T();
+  return result;
+}
+
 LogicalResult
 OperationLegalizer::legalizeWithFold(Operation *op,
                                      ConversionPatternRewriter &rewriter) {
@@ -2192,6 +2220,9 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
   RewriterState curState = rewriterImpl.getCurrentState();
   auto onFailure = [&](const Pattern &pattern) {
     assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+    rewriterImpl.patternNewOps.clear();
+    rewriterImpl.patternModifiedOps.clear();
+    rewriterImpl.patternInsertedBlocks.clear();
     LLVM_DEBUG({
       logFailure(rewriterImpl.logger, "pattern failed to match");
       if (rewriterImpl.config.notifyCallback) {
@@ -2212,7 +2243,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
   // successfully applied.
   auto onSuccess = [&](const Pattern &pattern) {
     assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
-    auto result = legalizePatternResult(op, pattern, rewriter, curState);
+    SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
+    SetVector<Operation *> modifiedOps =
+        moveAndReset(rewriterImpl.patternModifiedOps);
+    SetVector<Block *> insertedBlocks =
+        moveAndReset(rewriterImpl.patternInsertedBlocks);
+    auto result = legalizePatternResult(op, pattern, rewriter, newOps,
+                                        modifiedOps, insertedBlocks);
     appliedPatterns.erase(&pattern);
     if (failed(result)) {
       if (!rewriterImpl.config.allowPatternRollback)
@@ -2253,10 +2290,11 @@ bool OperationLegalizer::canApplyPattern(Operation *op, 
const Pattern &pattern,
   return true;
 }
 
-LogicalResult
-OperationLegalizer::legalizePatternResult(Operation *op, const Pattern 
&pattern,
-                                          ConversionPatternRewriter &rewriter,
-                                          RewriterState &curState) {
+LogicalResult OperationLegalizer::legalizePatternResult(
+    Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
+    const SetVector<Operation *> &newOps,
+    const SetVector<Operation *> &modifiedOps,
+    const SetVector<Block *> &insertedBlocks) {
   auto &impl = rewriter.getImpl();
   assert(impl.pendingRootUpdates.empty() && "dangling root updates");
 
@@ -2274,12 +2312,10 @@ OperationLegalizer::legalizePatternResult(Operation 
*op, const Pattern &pattern,
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
   // Legalize each of the actions registered during application.
-  RewriterState newState = impl.getCurrentState();
-  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
-                                          newState)) ||
-      failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
-      failed(legalizePatternCreatedOperations(rewriter, impl, curState,
-                                              newState))) {
+  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
+                                          newOps)) ||
+      failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
+      failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
     return failure();
   }
 
@@ -2289,20 +2325,14 @@ OperationLegalizer::legalizePatternResult(Operation 
*op, const Pattern &pattern,
 
 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
     Operation *op, ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &impl, RewriterState &state,
-    RewriterState &newState) {
-  SmallPtrSet<Operation *, 16> operationsToIgnore;
+    ConversionPatternRewriterImpl &impl,
+    const SetVector<Block *> &insertedBlocks,
+    const SetVector<Operation *> &newOps) {
+  SmallPtrSet<Operation *, 16> alreadyLegalized;
 
   // If the pattern moved or created any blocks, make sure the types of block
   // arguments get legalized.
-  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
-    BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
-    if (!rewrite)
-      continue;
-    Block *block = rewrite->getBlock();
-    if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
-            ReplaceBlockArgRewrite, InlineBlockRewrite>(rewrite))
-      continue;
+  for (Block *block : insertedBlocks) {
     // Only check blocks outside of the current operation.
     Operation *parentOp = block->getParentOp();
     if (!parentOp || parentOp == op || block->getNumArguments() == 0)
@@ -2322,41 +2352,26 @@ LogicalResult 
OperationLegalizer::legalizePatternBlockRewrites(
       continue;
     }
 
-    // Otherwise, check that this operation isn't one generated by this 
pattern.
-    // This is because we will attempt to legalize the parent operation, and
-    // blocks in regions created by this pattern will already be legalized 
later
-    // on. If we haven't built the set yet, build it now.
-    if (operationsToIgnore.empty()) {
-      for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
-           ++i) {
-        auto *createOp =
-            dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
-        if (!createOp)
-          continue;
-        operationsToIgnore.insert(createOp->getOperation());
+    // Otherwise, try to legalize the parent operation if it was not generated
+    // by this pattern. This is because we will attempt to legalize the parent
+    // operation, and blocks in regions created by this pattern will already be
+    // legalized later on.
+    if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
+      if (failed(legalize(parentOp, rewriter))) {
+        LLVM_DEBUG(logFailure(
+            impl.logger, "operation '{0}'({1}) became illegal after rewrite",
+            parentOp->getName(), parentOp));
+        return failure();
       }
     }
-
-    // If this operation should be considered for re-legalization, try it.
-    if (operationsToIgnore.insert(parentOp).second &&
-        failed(legalize(parentOp, rewriter))) {
-      LLVM_DEBUG(logFailure(impl.logger,
-                            "operation '{0}'({1}) became illegal after 
rewrite",
-                            parentOp->getName(), parentOp));
-      return failure();
-    }
   }
   return success();
 }
 
 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
-    RewriterState &state, RewriterState &newState) {
-  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
-    auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
-    if (!createOp)
-      continue;
-    Operation *op = createOp->getOperation();
+    const SetVector<Operation *> &newOps) {
+  for (Operation *op : newOps) {
     if (failed(legalize(op, rewriter))) {
       LLVM_DEBUG(logFailure(impl.logger,
                             "failed to legalize generated operation 
'{0}'({1})",
@@ -2369,12 +2384,8 @@ LogicalResult 
OperationLegalizer::legalizePatternCreatedOperations(
 
 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
-    RewriterState &state, RewriterState &newState) {
-  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
-    auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
-    if (!rewrite)
-      continue;
-    Operation *op = rewrite->getOperation();
+    const SetVector<Operation *> &modifiedOps) {
+  for (Operation *op : modifiedOps) {
     if (failed(legalize(op, rewriter))) {
       LLVM_DEBUG(logFailure(
           impl.logger, "failed to legalize operation updated in-place '{0}'",

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to