Author: Sean Silva
Date: 2020-11-30T17:04:14-08:00
New Revision: 774f1d3ffd458d6cb82d5039758ef1cf6370957f
URL: https://github.com/llvm/llvm-project/commit/774f1d3ffd458d6cb82d5039758ef1cf6370957f
DIFF: https://github.com/llvm/llvm-project/commit/774f1d3ffd458d6cb82d5039758ef1cf6370957f.diff

LOG: [mlir] Small cleanups to func-bufferize/finalizing-bufferize

- Address TODO in scf-bufferize: the argument materialization issue is now
  fixed and the code now lives in Transforms/Bufferize.cpp

- Tighten up finalizing-bufferize to avoid creating invalid IR when operand
  types potentially change

- Tidy up the testing of func-bufferize, and move appropriate tests to a new
  finalizing-bufferize.mlir

- The new stricter checking in finalizing-bufferize revealed that we needed a
  DimOp conversion pattern (found when integrating into npcomp). Previously,
  the conversion infrastructure was blindly changing the operand type during
  finalization, which happened to work due to DimOp's tensor/memref
  polymorphism, but is generally not encouraged (the new pattern is the way to
  tell the conversion infrastructure that it is legal to change that type).
  A small before/after sketch of this rewrite appears at the end of this
  message.

Added:
    mlir/test/Transforms/finalizing-bufferize.mlir

Modified:
    mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
    mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
    mlir/lib/Transforms/Bufferize.cpp
    mlir/test/Dialect/Standard/bufferize.mlir
    mlir/test/Dialect/Standard/func-bufferize.mlir

Removed:
    mlir/test/Dialect/Standard/func-bufferize-partial.mlir


################################################################################

diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
index 57d605b3491f7..7cf0dfabd9174 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
@@ -27,21 +27,6 @@ struct SCFBufferizePass : public SCFBufferizeBase<SCFBufferizePass> {
     OwningRewritePatternList patterns;
     ConversionTarget target(*context);
 
-    // TODO: Move this to BufferizeTypeConverter's constructor.
-    //
-    // This doesn't currently play well with "finalizing" bufferizations (ones
-    // that expect all materializations to be gone). In particular, there seems
-    // to at least be a double-free in the dialect conversion framework
-    // when this materialization gets inserted and then folded away because
-    // it is marked as illegal.
-    typeConverter.addArgumentMaterialization(
-        [](OpBuilder &builder, RankedTensorType type, ValueRange inputs,
-           Location loc) -> Value {
-          assert(inputs.size() == 1);
-          assert(inputs[0].getType().isa<BaseMemRefType>());
-          return builder.create<TensorLoadOp>(loc, type, inputs[0]);
-        });
-
     populateBufferizeMaterializationLegality(target);
     populateSCFStructuralTypeConversionsAndLegality(context, typeConverter,
                                                     patterns, target);

diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index 9056fbc25e14d..8b47e88677e2d 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -20,6 +20,21 @@
 using namespace mlir;
 
+namespace {
+class BufferizeDimOp : public OpConversionPattern<DimOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(DimOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    DimOp::Adaptor adaptor(operands);
+    rewriter.replaceOpWithNewOp<DimOp>(op, adaptor.memrefOrTensor(),
+                                       adaptor.index());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class BufferizeDynamicTensorFromElementsOp
     : public OpConversionPattern<DynamicTensorFromElementsOp> {
@@ -148,6 +163,7 @@ void mlir::populateStdBufferizePatterns(MLIRContext *context,
                                         OwningRewritePatternList &patterns) {
   patterns.insert<
       // clang-format off
+      BufferizeDimOp,
      BufferizeDynamicTensorFromElementsOp,
      BufferizeExtractElementOp,
      BufferizeSelectOp,
@@ -178,6 +194,8 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
       return typeConverter.isLegal(op.getType()) ||
              !op.condition().getType().isa<IntegerType>();
     });
+    target.addDynamicallyLegalOp<DimOp>(
+        [&](DimOp op) { return typeConverter.isLegal(op); });
     if (failed(
             applyPartialConversion(getFunction(), target, std::move(patterns))))
       signalPassFailure();

diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
index 1811ac8bdfbca..66b1cc65646c1 100644
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -105,13 +105,17 @@ struct FinalizingBufferizePass
     populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
                                                        patterns);
 
-    target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
     // If all result types are legal, and all block arguments are legal (ensured
     // by func conversion above), then all types in the program are legal.
-    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
-      return typeConverter.isLegal(op->getResultTypes());
-    });
+    //
+    // We also check that the operand types are legal to avoid creating invalid
+    // IR. For example, this prevents
+    // populateEliminateBufferizeMaterializationsPatterns from updating the
+    // types of the operands to a return op without updating the enclosing
+    // function.
+    target.markUnknownOpDynamicallyLegal(
+        [&](Operation *op) { return typeConverter.isLegal(op); });
 
     if (failed(applyFullConversion(func, target, std::move(patterns))))
       signalPassFailure();

diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index 8cc05ff20644b..27769c52d9ea4 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -1,5 +1,16 @@
 // RUN: mlir-opt %s -std-bufferize | FileCheck %s
 
+// CHECK-LABEL:   func @dim(
+// CHECK-SAME:        %[[TENSOR:.*]]: tensor<f32>,
+// CHECK-SAME:        %[[INDEX:.*]]: index) -> index {
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK:           %[[EXTENT:.*]] = dim %[[MEMREF]], %[[INDEX]] : memref<f32>
+// CHECK:           return %[[EXTENT]] : index
+func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
+  %0 = dim %arg0, %arg1 : tensor<f32>
+  return %0 : index
+}
+
 // CHECK-LABEL:   func @dynamic_tensor_from_elements(
 // CHECK-SAME:        %[[ARG:.*]]: tensor<*xf32>,
 // CHECK-SAME:        %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
@@ -7,7 +18,8 @@
 // CHECK:           %[[C0:.*]] = constant 0 : index
 // CHECK:           %[[C1:.*]] = constant 1 : index
 // CHECK:           scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
-// CHECK:             %[[ELEM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+// CHECK:             %[[ARG_MEMREF:.*]] = tensor_to_memref %[[ARG]] : memref<*xf32>
+// CHECK:             %[[ELEM:.*]] = dim %[[ARG_MEMREF]], %[[I]] : memref<*xf32>
 // CHECK:             store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
 // CHECK:             scf.yield
 // CHECK:           }

diff --git a/mlir/test/Dialect/Standard/func-bufferize-partial.mlir b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir
deleted file mode 100644
index 43ea4591e4e35..0000000000000
--- a/mlir/test/Dialect/Standard/func-bufferize-partial.mlir
+++ /dev/null
@@ -1,59 +0,0 @@
-// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s
-
-// CHECK-LABEL:   func @block_arguments(
-// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK:           %[[T1:.*]] = tensor_load %[[ARG]] : memref<f32>
-// CHECK:           %[[M1:.*]] = tensor_to_memref %[[T1]] : memref<f32>
-// CHECK:           br ^bb1(%[[M1]] : memref<f32>)
-// CHECK:         ^bb1(%[[BBARG:.*]]: memref<f32>):
-// CHECK:           %[[T2:.*]] = tensor_load %[[BBARG]] : memref<f32>
-// CHECK:           %[[M2:.*]] = tensor_to_memref %[[T2]] : memref<f32>
-// CHECK:           return %[[M2]] : memref<f32>
-func @block_arguments(%arg0: tensor<f32>) -> tensor<f32> {
-  br ^bb1(%arg0: tensor<f32>)
-^bb1(%bbarg: tensor<f32>):
-  return %bbarg : tensor<f32>
-}
-
-// CHECK-LABEL: func @partial()
-// CHECK-SAME: memref<f32>
-func @partial() -> tensor<f32> {
-  // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor<f32>
-  // CHECK-NEXT: %[[MEM:.*]] = tensor_to_memref %[[SRC]] : memref<f32>
-  %0 = "test.source"() : () -> tensor<f32>
-  // CHECK-NEXT: return %[[MEM]] : memref<f32>
-  return %0 : tensor<f32>
-}
-
-// CHECK-LABEL: func @region_op
-// CHECK-SAME: (%[[ARG0:.*]]: i1) -> memref<f32>
-func @region_op(%arg0: i1) -> tensor<f32> {
-  // CHECK-NEXT: %[[IF:.*]] = scf.if %[[ARG0]] -> (tensor<f32>)
-  %0 = scf.if %arg0 -> (tensor<f32>) {
-    // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor<f32>
-    %1 = "test.source"() : () -> tensor<f32>
-    // CHECK-NEXT: scf.yield %[[SRC]] : tensor<f32>
-    scf.yield %1 : tensor<f32>
-  // CHECK-NEXT: else
-  } else {
-    // CHECK-NEXT: %[[OSRC:.*]] = "test.other_source"() : () -> tensor<f32>
-    %1 = "test.other_source"() : () -> tensor<f32>
-    // CHECK-NEXT: scf.yield %[[OSRC]] : tensor<f32>
-    scf.yield %1 : tensor<f32>
-  }
-  // CHECK: %[[MEM:.*]] = tensor_to_memref %[[IF]] : memref<f32>
-  // CHECK: return %[[MEM]] : memref<f32>
-  return %0 : tensor<f32>
-}
-
-// -----
-
-func @failed_to_legalize(%arg0: tensor<f32>) -> tensor<f32> {
-  %0 = constant true
-  cond_br %0, ^bb1(%arg0: tensor<f32>), ^bb2(%arg0: tensor<f32>)
-  ^bb1(%bbarg0: tensor<f32>):
-  // expected-error @+1 {{failed to legalize operation 'test.terminator'}}
-  "test.terminator"() : () -> ()
-  ^bb2(%bbarg1: tensor<f32>):
-  return %bbarg1 : tensor<f32>
-}

diff --git a/mlir/test/Dialect/Standard/func-bufferize.mlir b/mlir/test/Dialect/Standard/func-bufferize.mlir
index d02db99aecd83..de2f75c4a293b 100644
--- a/mlir/test/Dialect/Standard/func-bufferize.mlir
+++ b/mlir/test/Dialect/Standard/func-bufferize.mlir
@@ -1,39 +1,29 @@
-// RUN: mlir-opt %s -func-bufferize -finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s
 
 // CHECK-LABEL:   func @identity(
-// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK:           return %[[ARG]] : memref<f32>
+// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK:           %[[TENSOR:.*]] = tensor_load %[[ARG]] : memref<f32>
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK:           return %[[MEMREF]] : memref<f32>
 func @identity(%arg0: tensor<f32>) -> tensor<f32> {
   return %arg0 : tensor<f32>
 }
 
 // CHECK-LABEL:   func @block_arguments(
 // CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK:           br ^bb1(%[[ARG]] : memref<f32>)
+// CHECK:           %[[T1:.*]] = tensor_load %[[ARG]] : memref<f32>
+// CHECK:           %[[M1:.*]] = tensor_to_memref %[[T1]] : memref<f32>
+// CHECK:           br ^bb1(%[[M1]] : memref<f32>)
 // CHECK:         ^bb1(%[[BBARG:.*]]: memref<f32>):
-// CHECK:           return %[[BBARG]] : memref<f32>
+// CHECK:           %[[T2:.*]] = tensor_load %[[BBARG]] : memref<f32>
+// CHECK:           %[[M2:.*]] = tensor_to_memref %[[T2]] : memref<f32>
+// CHECK:           return %[[M2]] : memref<f32>
 func @block_arguments(%arg0: tensor<f32>) -> tensor<f32> {
   br ^bb1(%arg0: tensor<f32>)
 ^bb1(%bbarg: tensor<f32>):
   return %bbarg : tensor<f32>
 }
 
-// CHECK-LABEL:   func @eliminate_target_materialization(
-// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK:           return %[[ARG]] : memref<f32>
-func @eliminate_target_materialization(%arg0: tensor<f32>) -> memref<f32> {
-  %0 = tensor_to_memref %arg0 : memref<f32>
-  return %0 : memref<f32>
-}
-
-// CHECK-LABEL:   func @eliminate_source_materialization(
-// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK:           return %[[ARG]] : memref<f32>
-func @eliminate_source_materialization(%arg0: memref<f32>) -> tensor<f32> {
-  %0 = tensor_load %arg0 : memref<f32>
-  return %0 : tensor<f32>
-}
-
 // CHECK-LABEL:   func private @source() -> memref<f32>
 // CHECK-LABEL:   func @call_source() -> memref<f32> {
 // CHECK:           %[[RET:.*]] = call @source() : () -> memref<f32>
@@ -43,11 +33,11 @@ func @call_source() -> tensor<f32> {
   %0 = call @source() : () -> tensor<f32>
   return %0 : tensor<f32>
 }
-
-// CHECK-LABEL: func private @sink(memref<f32>)
 // CHECK-LABEL:   func @call_sink(
-// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) {
-// CHECK:           call @sink(%[[ARG]]) : (memref<f32>) -> ()
+// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) {
+// CHECK:           %[[TENSOR:.*]] = tensor_load %[[ARG]] : memref<f32>
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK:           call @sink(%[[MEMREF]]) : (memref<f32>) -> ()
 // CHECK:           return
 func private @sink(tensor<f32>)
 func @call_sink(%arg0: tensor<f32>) {
@@ -55,10 +45,25 @@ func @call_sink(%arg0: tensor<f32>) {
   return
 }
 
+// CHECK-LABEL:   func @unconverted_op_in_body() -> memref<f32> {
+// CHECK:           %[[TENSOR:.*]] = "test.source"() : () -> tensor<f32>
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK:           return %[[MEMREF]] : memref<f32>
+func @unconverted_op_in_body() -> tensor<f32> {
+  %0 = "test.source"() : () -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
 // -----
 
-func @failed_to_legalize() -> tensor<f32> {
-  // expected-error @+1 {{failed to legalize operation 'test.source'}}
-  %0 = "test.source"() : () -> (tensor<f32>)
-  return %0 : tensor<f32>
+// Because this pass updates block arguments, it needs to also atomically
+// update all terminators and issue an error if that is not possible.
+func @unable_to_update_terminator(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = constant true
+  cond_br %0, ^bb1(%arg0: tensor<f32>), ^bb2(%arg0: tensor<f32>)
+  ^bb1(%bbarg0: tensor<f32>):
+  // expected-error @+1 {{failed to legalize operation 'test.terminator'}}
+  "test.terminator"() : () -> ()
+  ^bb2(%bbarg1: tensor<f32>):
+  return %bbarg1 : tensor<f32>
 }

diff --git a/mlir/test/Transforms/finalizing-bufferize.mlir b/mlir/test/Transforms/finalizing-bufferize.mlir
new file mode 100644
index 0000000000000..5c09664776ead
--- /dev/null
+++ b/mlir/test/Transforms/finalizing-bufferize.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL:   func @eliminate_materializations(
+// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK:           return %[[ARG]] : memref<f32>
+func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
+  %0 = tensor_load %arg0 : memref<f32>
+  %1 = tensor_to_memref %0 : memref<f32>
+  return %1 : memref<f32>
+}
+
+// -----
+
+func @unable_to_convert_lone_tensor_to_memref() -> memref<f32> {
+  // expected-error @+1 {{failed to legalize operation 'test.source'}}
+  %0 = "test.source"() : () -> tensor<f32>
+  %1 = tensor_to_memref %0 : memref<f32>
+  return %1 : memref<f32>
+}
+
+// -----
+
+func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
+  %0 = tensor_load %arg0 : memref<f32>
+  // expected-error @+1 {{failed to legalize operation 'test.sink'}}
+  "test.sink"(%0) : (tensor<f32>) -> ()
+  return
+}
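For readers following the DimOp change: the rewrite that the new BufferizeDimOp
pattern performs under -std-bufferize is the one the @dim test above checks for.
A small before/after sketch, reconstructed from those CHECK lines (illustration
only, not an additional change in this commit):

  // Before bufferization: dim still consumes a tensor.
  func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
    %0 = dim %arg0, %arg1 : tensor<f32>
    return %0 : index
  }

  // After -std-bufferize: the pattern materializes the memref operand with
  // tensor_to_memref and rebuilds dim on it, rather than relying on the
  // conversion infrastructure to silently change the operand type.
  func @dim(%arg0: tensor<f32>, %arg1: index) -> index {
    %0 = tensor_to_memref %arg0 : memref<f32>
    %1 = dim %0, %arg1 : memref<f32>
    return %1 : index
  }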
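Similarly, the operand-type check now done by finalizing-bufferize guards
against the return-op case described in the new comment in
Transforms/Bufferize.cpp. A hypothetical example (contrived for illustration;
the committed tests exercise the same check through test.source and test.sink
instead):

  // Input to -finalizing-bufferize: the enclosing function still returns a
  // tensor.
  func @f(%arg0: memref<f32>) -> tensor<f32> {
    %0 = tensor_load %arg0 : memref<f32>
    return %0 : tensor<f32>
  }

  // With result-only legality, folding away the tensor_load would rewrite the
  // terminator to "return %arg0 : memref<f32>" inside a function declared to
  // return tensor<f32>, which is invalid IR. With operands checked as well,
  // the return op remains illegal and the pass reports a legalization failure
  // instead of producing that IR.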