llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-core Author: Mehdi Amini (joker-eph) <details> <summary>Changes</summary> This is similar to the fix to the greedy driver in #153957, except that instead of removing unreachable code, we simply skip it. Because SSA dominance is not enforced in unreachable blocks, operations like: ``` %add = arith.addi %add, %add : i64 ``` are legal in unreachable code. Unfortunately, many patterns are unsafe to apply to such IR and can lead to crashes or infinite loops. --- Full diff: https://github.com/llvm/llvm-project/pull/154038.diff 3 Files Affected: - (modified) mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h (+2) - (modified) mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp (+58-18) - (modified) mlir/test/IR/test-walk-pattern-rewrite-driver.mlir (+10) ``````````diff diff --git a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h index 6d62ae3dd43dc..7d5c1d5cebb26 100644 --- a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h @@ -27,6 +27,8 @@ namespace mlir { /// This is intended as the simplest and most lightweight pattern rewriter in /// cases when a simple walk gets the job done. /// +/// The driver will skip unreachable blocks. +/// /// Note: Does not apply patterns to the given operation itself. void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp index 52f8ea5472883..8f26a294f6d9b 100644 --- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp @@ -27,6 +27,26 @@ namespace mlir { +// Find all reachable blocks in the region and add them to the reachableBlocks +// set.
+static void findReachableBlocks(Region &region, + DenseSet<Block *> &reachableBlocks) { + Block *entryBlock = &region.front(); + reachableBlocks.insert(entryBlock); + // Traverse the CFG and add all reachable blocks to the reachableBlocks set. + SmallVector<Block *> worklist({entryBlock}); + while (!worklist.empty()) { + Block *block = worklist.pop_back_val(); + Operation *terminator = &block->back(); + for (Block *successor : terminator->getSuccessors()) { + // Enqueue a successor only the first time it is discovered. + if (!reachableBlocks.insert(successor).second) + continue; + worklist.push_back(successor); + } + } +} + namespace { struct WalkAndApplyPatternsAction final : tracing::ActionImpl<WalkAndApplyPatternsAction> { @@ -90,18 +110,28 @@ void walkAndApplyPatterns(Operation *op, PatternApplicator applicator(patterns); applicator.applyDefaultCostModel(); - // Cursor to track where we're at in the traversal. - struct Cursor { - Cursor(Region *region) : region(region) { + // Iterator on all reachable operations in the region. + // Also keep track if we visited the nested regions of the current op + // already to drive the post-order traversal. + struct RegionReachableOpIterator { + RegionReachableOpIterator(Region *region) : region(region) { regionIt = region->begin(); if (regionIt != region->end()) blockIt = regionIt->begin(); + if (!llvm::hasSingleElement(*region)) + findReachableBlocks(*region, reachableBlocks); } - void next() { + // Advance the iterator to the next reachable operation. + void advance() { assert(regionIt != region->end()); hasVisitedRegions = false; if (blockIt == regionIt->end()) { regionIt++; + while (regionIt != region->end() && + !reachableBlocks.contains(&*regionIt)) + regionIt++; + if (regionIt != region->end()) + blockIt = regionIt->begin(); return; } blockIt++; @@ -110,14 +140,23 @@ void walkAndApplyPatterns(Operation *op, << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions()); } } + + // The region we're iterating over.
Region *region; + // The Block currently being iterated over. Region::iterator regionIt; + // The Operation currently being iterated over. Block::iterator blockIt; + // The set of blocks that are reachable in the current region. + DenseSet<Block *> reachableBlocks; + // Whether we've visited the nested regions of the current op already. bool hasVisitedRegions = false; }; - SmallVector<Cursor> worklist; + SmallVector<RegionReachableOpIterator> worklist; LDBG() << "Starting walk-based pattern rewrite driver"; + // Perform a post-order traversal of the region, visiting each reachable + // operation. ctx->executeAction<WalkAndApplyPatternsAction>( [&] { for (Region &region : op->getRegions()) { @@ -128,36 +167,37 @@ void walkAndApplyPatterns(Operation *op, // Prime the worklist with the entry block of this region. worklist.push_back({&region}); while (!worklist.empty()) { - Cursor &cursor = worklist.back(); - if (cursor.regionIt == cursor.region->end()) { + RegionReachableOpIterator &it = worklist.back(); + if (it.regionIt == it.region->end()) { // We're done with this region. worklist.pop_back(); continue; } - if (cursor.blockIt == cursor.regionIt->end()) { + if (it.blockIt == it.regionIt->end()) { // We're done with this block. - cursor.regionIt++; - if (cursor.regionIt != cursor.region->end()) - cursor.blockIt = cursor.regionIt->begin(); + it.advance(); continue; } - Operation *op = &*cursor.blockIt; - if (!cursor.hasVisitedRegions) { - cursor.hasVisitedRegions = true; + Operation *op = &*it.blockIt; + // If we haven't visited the nested regions of this op yet, + // enqueue them. + if (!it.hasVisitedRegions) { + it.hasVisitedRegions = true; for (Region &nestedRegion : llvm::reverse(op->getRegions())) { if (nestedRegion.empty()) continue; worklist.push_back({&nestedRegion}); } } - // If we're not at the back of the worklist, we're visiting a nested - // region first. We'll come back to this op later.
- if (&cursor != &worklist.back()) + // If we're not at the back of the worklist, we've enqueued some + // nested region for processing. We'll come back to this op later + // (post-order) + if (&it != &worklist.back()) continue; // Premptively increment the cursor, in case the current op // would be erased. - cursor.next(); + it.advance(); LDBG() << "Visiting op: " << OpWithFlags(op, OpPrintingFlags().skipRegions()); diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir index c75c478ec3734..1acff6fdf029e 100644 --- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir +++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir @@ -119,3 +119,13 @@ func.func @erase_nested_block() -> i32 { }): () -> (i32) return %a : i32 } + + +// CHECK-LABEL: func.func @unreachable_replace_with_new_op +// CHECK: "test.replace_with_new_op" +func.func @unreachable_replace_with_new_op() { + return +^unreachable: + %a = "test.replace_with_new_op"() : () -> (i32) + return +} `````````` </details> https://github.com/llvm/llvm-project/pull/154038 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits