https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/152912
Previously, the pass accessed erased operations and block arguments from within the type converter. The new conversion driver no longer supports this, so the pass now disables pattern rollback and uses a rewriter listener to drop erased operations and block arguments from its bookkeeping containers. From c2e90f3a39148223619497eeff16ed810e3cab95 Mon Sep 17 00:00:00 2001 From: Matthias Springer <m...@m-sp.org> Date: Sun, 10 Aug 2025 11:41:51 +0000 Subject: [PATCH] [mlir][linalg] Migrate Detensorize pass to new dialect conversion driver --- .../Dialect/Linalg/Transforms/Detensorize.cpp | 34 +++++++++++++++++-- mlir/test/Dialect/Linalg/detensorize_0d.mlir | 7 ++-- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index 830905495e759..221f95a8d8f33 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -458,6 +458,22 @@ struct LinalgDetensorize } }; + /// A listener that forwards notifyBlockErased and notifyOperationErased to + /// the given callbacks. + struct CallbackListener : public RewriterBase::Listener { + CallbackListener(std::function<void(Operation *op)> onOperationErased, + std::function<void(Block *block)> onBlockErased) + : onOperationErased(onOperationErased), onBlockErased(onBlockErased) {} + + void notifyBlockErased(Block *block) override { onBlockErased(block); } + void notifyOperationErased(Operation *op) override { + onOperationErased(op); + } + + std::function<void(Operation *op)> onOperationErased; + std::function<void(Block *block)> onBlockErased; + }; + void runOnOperation() override { MLIRContext *context = &getContext(); DetensorizeTypeConverter typeConverter; @@ -551,8 +567,22 @@ struct LinalgDetensorize populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter, shouldConvertBranchOperand); - if (failed( - applyFullConversion(getOperation(), target, std::move(patterns)))) + ConversionConfig config; + auto onOperationErased = [&](Operation *op) { + opsToDetensor.erase(op); + 
detensorableBranchOps.erase(op); + }; + auto onBlockErased = [&](Block *block) { + for (BlockArgument arg : block->getArguments()) { + blockArgsToDetensor.erase(arg); + } + }; + CallbackListener listener(onOperationErased, onBlockErased); + + config.listener = &listener; + config.allowPatternRollback = false; + if (failed(applyFullConversion(getOperation(), target, std::move(patterns), + config))) signalPassFailure(); RewritePatternSet canonPatterns(context); diff --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir index 74931cb0830bc..5c29b04630cad 100644 --- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir @@ -53,10 +53,11 @@ func.func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tenso } // CHECK-LABEL: func @detensor_op_sequence // CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>) -// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]] +// CHECK-DAG: %[[arg1_val_1:.*]] = tensor.extract %[[arg1]] // CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]] -// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]] -// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]] +// CHECK-DAG: %[[arg1_val_2:.*]] = tensor.extract %[[arg1]] +// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val_2]], %[[arg2_val]] +// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val_1]], %[[detensored_res]] // CHECK: %[[detensored_res3:.*]] = arith.divf %[[detensored_res]], %[[detensored_res2]] // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]] // CHECK: return %[[new_tensor_res]] _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits