https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/154038
>From e5070afea324c3fa3bde2b2da3e605a7d8c43c7c Mon Sep 17 00:00:00 2001 From: Mehdi Amini <joker....@gmail.com> Date: Sun, 17 Aug 2025 14:24:35 -0700 Subject: [PATCH] [MLIR] Stop visiting unreachable blocks in the walkAndApplyPatterns driver This is similar to the fix to the greedy driver in #153957, except that instead of removing unreachable code, we just ignore it. Operations like: %add = arith.addi %add, %add : i64 are legal in unreachable code. Unfortunately, many patterns would be unsafe to apply on such IR and can lead to crashes or infinite loops. --- .../Transforms/WalkPatternRewriteDriver.h | 2 ++ .../Utils/WalkPatternRewriteDriver.cpp | 27 +++++++++++++++++++ .../IR/test-walk-pattern-rewrite-driver.mlir | 10 +++++++ 3 files changed, 39 insertions(+) diff --git a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h index 6d62ae3dd43dc..7d5c1d5cebb26 100644 --- a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h @@ -27,6 +27,8 @@ namespace mlir { /// This is intended as the simplest and most lightweight pattern rewriter in /// cases when a simple walk gets the job done. /// +/// The driver will skip unreachable blocks. +/// /// Note: Does not apply patterns to the given operation itself. void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp index 03a6e59aab4d9..3e02ab20aa53d 100644 --- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp @@ -27,6 +27,26 @@ namespace mlir { +// Find all reachable blocks in the region and add them to the reachableBlocks +// set.
+static void findReachableBlocks(Region &region, + DenseSet<Block *> &reachableBlocks) { + // An empty region has no entry block; front() would be invalid. + if (region.empty()) + return; + Block *entryBlock = &region.front(); + reachableBlocks.insert(entryBlock); + // Traverse the CFG and add all reachable blocks to reachableBlocks. + SmallVector<Block *> worklist({entryBlock}); + while (!worklist.empty()) { + // Pop inside the loop so every newly discovered block is expanded. + Block *block = worklist.pop_back_val(); + for (Block *successor : block->getSuccessors()) + if (reachableBlocks.insert(successor).second) + worklist.push_back(successor); + } +} + namespace { struct WalkAndApplyPatternsAction final : tracing::ActionImpl<WalkAndApplyPatternsAction> { @@ -98,6 +118,8 @@ void walkAndApplyPatterns(Operation *op, regionIt = region->begin(); if (regionIt != region->end()) blockIt = regionIt->begin(); + if (!llvm::hasSingleElement(*region)) + findReachableBlocks(*region, reachableBlocks); } // Advance the iterator to the next reachable operation. void advance() { @@ -105,6 +127,9 @@ void walkAndApplyPatterns(Operation *op, hasVisitedRegions = false; if (blockIt == regionIt->end()) { regionIt++; + while (regionIt != region->end() && + !reachableBlocks.contains(&*regionIt)) + regionIt++; if (regionIt != region->end()) blockIt = regionIt->begin(); return; @@ -121,6 +146,8 @@ void walkAndApplyPatterns(Operation *op, Region::iterator regionIt; // The Operation currently being iterated over. Block::iterator blockIt; + // The set of blocks that are reachable in the current region. + DenseSet<Block *> reachableBlocks; // Whether we've visited the nested regions of the current op already.
bool hasVisitedRegions = false; }; diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir index c75c478ec3734..1acff6fdf029e 100644 --- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir +++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir @@ -119,3 +119,13 @@ func.func @erase_nested_block() -> i32 { }): () -> (i32) return %a : i32 } + + +// CHECK-LABEL: func.func @unreachable_replace_with_new_op +// CHECK: "test.replace_with_new_op" +func.func @unreachable_replace_with_new_op() { + return +^unreachable: + %a = "test.replace_with_new_op"() : () -> (i32) + return +} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits