https://github.com/rafaelubalmw updated https://github.com/llvm/llvm-project/pull/74626
>From 66287c8d3d23cfd3003baf82160013514b4bedb5 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <ru...@mathworks.com>
Date: Tue, 5 Dec 2023 22:54:16 -0500
Subject: [PATCH 1/4] Progress in 'tensor.splat' extensions

---
 .../mlir/Dialect/Tensor/IR/TensorOps.td  | 49 +++++++++++++------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 43 ++++++++++++++++
 2 files changed, 77 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index f50e3464867be5..60f188607e454c 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1728,6 +1728,7 @@ def Tensor_ScatterOp : Tensor_Op<"scatter", [
 
 def Tensor_SplatOp : Tensor_Op<"splat", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
     Pure,
     TypesMatchWith<"operand type matches element type of result",
                    "aggregate", "input",
@@ -1736,38 +1737,56 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
   let summary = "tensor splat or broadcast operation";
   let description = [{
     Broadcast the operand to all elements of the result tensor. The operand is
-    required to be of integer/index/float type, and the result tensor must be
-    statically shaped.
+    required to be of integer/index/float type.
 
-    Example:
+    An additional argument of type `index` must be provided for each dynamic
+    dimension present in the result type.
+
+    Example for a statically shaped tensor:
 
     ```mlir
     %s = arith.constant 10.1 : f32
     %t = tensor.splat %s : tensor<8x16xf32>
     ```
 
-    TODO: This operation is easy to extend to broadcast to dynamically shaped
-    tensors:
+    Example for a tensor containing dynamic dimensions:
 
     ```mlir
-    // Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding
-    // to the sizes of the two dynamic dimensions.
-    %m = "foo"() : () -> (index)
-    %n = "bar"() : () -> (index)
-    %t = tensor.splat %s [%m, %n] : tensor<?x?xf32>
+    // Broadcasts %s to a 3D dynamically shaped tensor, with %m and %n binding
+    // to dimensions 0 and 2 of the resulting tensor, respectively.
+    %m = arith.constant 10 : index
+    %n = arith.constant 30 : index
+    %t = tensor.splat %s[%m, %n] : tensor<?x20x?xf32>
     ```
   }];
 
   let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
-                                 "integer/index/float type">:$input);
-  let results = (outs AnyStaticShapeTensor:$aggregate);
+                                 "integer/index/float type">:$input,
+                       Variadic<Index>:$dynamicSizes);
+  let results = (outs AnyRankedTensor:$aggregate);
 
   let builders = [
-    OpBuilder<(ins "Value":$element, "Type":$aggregateType),
-    [{ build($_builder, $_state, aggregateType, element); }]>];
-  let assemblyFormat = "$input attr-dict `:` type($aggregate)";
+    // Build with an explicit result type and a list of values corresponding
+    // to the dynamic sizes present in the result type.
+    OpBuilder<(ins "Value":$element,
+                   "Type":$aggregateType,
+                   CArg<"ValueRange", "{}">:$dynamicSizes)>,
+
+    // Build with a result tensor shape and a list of values corresponding to
+    // the elements in the result tensor shape set to ShapedType::kDynamic.
+    OpBuilder<(ins "Value":$element,
+                   "ArrayRef<int64_t>":$staticShape,
+                   CArg<"ValueRange", "{}">:$dynamicSizes)>,
+
+    // Build with mixed static/dynamic sizes, where an attribute represents
+    // a static dimension and a value represents a dynamic dimension.
+    OpBuilder<(ins "Value":$element, "ArrayRef<OpFoldResult>":$sizes)>
+  ];
+
+  let assemblyFormat = "$input (`[` $dynamicSizes^ `]`)? attr-dict `:` type($aggregate)";
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
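[Editor's note: a minimal C++ sketch of how the new builders declared above might be exercised. The helper name `buildDynamicSplat` and the concrete shape `tensor<?x20x?xf32>` are illustrative assumptions, not part of this patch.]

```c++
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper: splat `element` into a tensor<?x20x?xf32>, binding
// `m` and `n` to dynamic dimensions 0 and 2.
static Value buildDynamicSplat(OpBuilder &b, Location loc, Value element,
                               Value m, Value n) {
  // First builder: explicit result type plus one index value per '?' dim.
  auto type = RankedTensorType::get(
      {ShapedType::kDynamic, 20, ShapedType::kDynamic}, element.getType());
  Value byType =
      b.create<tensor::SplatOp>(loc, element, type, ValueRange{m, n});

  // Third builder: mixed static/dynamic sizes as OpFoldResults; attributes
  // become static extents, values become dynamic size operands.
  SmallVector<OpFoldResult> sizes = {m, b.getIndexAttr(20), n};
  Value bySizes = b.create<tensor::SplatOp>(loc, element, sizes);
  (void)bySizes; // Both calls produce equivalent ops.
  return byType;
}
```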
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f15695383d34ab..b5e15d8a6d4571 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3430,11 +3430,54 @@ LogicalResult ScatterOp::verify() {
 // SplatOp
 //===----------------------------------------------------------------------===//
 
+void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
+                    Type aggregateType, ValueRange dynamicSizes) {
+  build(builder, result, aggregateType, element, dynamicSizes);
+}
+
+void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
+                    ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
+  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
+  build(builder, result, aggregateType, element, dynamicSizes);
+}
+
+void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
+                    ArrayRef<OpFoldResult> sizes) {
+  SmallVector<int64_t> staticShape;
+  SmallVector<Value> dynamicSizes;
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
+  build(builder, result, element, staticShape, dynamicSizes);
+}
+
 void SplatOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getResult(), "splat");
 }
 
+LogicalResult SplatOp::verify() {
+  if (getType().getNumDynamicDims() !=
+      static_cast<int64_t>(getDynamicSizes().size()))
+    return emitOpError("incorrect number of dynamic sizes, has ")
+           << getDynamicSizes().size() << ", expected "
+           << getType().getNumDynamicDims();
+  return success();
+}
+
+LogicalResult
+SplatOp::reifyResultShapes(OpBuilder &builder,
+                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
+  unsigned ctr = 0;
+  for (int64_t i = 0; i < getType().getRank(); ++i) {
+    if (getType().isDynamicDim(i)) {
+      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
+    } else {
+      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
+    }
+  }
+  return success();
+}
+
 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
   auto constOperand = adaptor.getInput();
   if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())

>From 823899dd3977768c4d99c63efc4d356661f2d46d Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <ru...@mathworks.com>
Date: Wed, 6 Dec 2023 09:49:22 -0500
Subject: [PATCH 2/4] Added unit tests

---
 mlir/test/Dialect/Tensor/bufferize.mlir | 20 ++++++++++++++++++++
 mlir/test/Dialect/Tensor/invalid.mlir   |  8 ++++++++
 mlir/test/Dialect/Tensor/ops.mlir       | 10 ++++++++++
 3 files changed, 38 insertions(+)

diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index a8b3c6af9ae893..e3c6ebbeb9d916 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -602,3 +602,23 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
   %t = tensor.splat %f : tensor<10x2x4xf32>
   return %t : tensor<10x2x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor.splat.dynamic(
+// CHECK-SAME:  %[[F:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME:  %[[M:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:  %[[N:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG:   %[[ALLOC:.*]] = memref.alloc(%[[M]], %[[N]]) {{.*}} : memref<?x3x?xf32>
+// CHECK:       %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:       %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<?x3x?xf32>)
+// CHECK:       () {
+// CHECK:         linalg.yield %[[F]] : f32
+// CHECK:       }
+// CHECK:       return %[[MAPPED]] : tensor<?x3x?xf32>
+// CHECK:     }
+func.func @tensor.splat.dynamic(%f: f32, %m: index, %n: index) -> tensor<?x3x?xf32> {
+  %0 = tensor.splat %f[%m, %n] : tensor<?x3x?xf32>
+  return %0 : tensor<?x3x?xf32>
+}
+
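[Editor's note: the bufferization test above works because the op now implements `ReifyRankedShapedTypeOpInterface`: the allocation sizes `%m` and `%n` are recovered through `reifyResultShapes` rather than by pattern-matching the splat. A minimal sketch of querying it, assuming a `tensor::SplatOp` handle inside a pass; the function is hypothetical.]

```c++
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

// Hypothetical: print one entry per result dimension. For the test above
// (result type tensor<?x3x?xf32>), shapes[0] holds {%m, IndexAttr(3), %n}.
static LogicalResult printSplatSizes(OpBuilder &b, tensor::SplatOp splat) {
  ReifiedRankedShapedTypeDims shapes;
  if (failed(splat.reifyResultShapes(b, shapes)))
    return failure();
  for (OpFoldResult size : shapes[0]) {
    if (auto attr = size.dyn_cast<Attribute>())
      llvm::errs() << "static: " << attr << "\n"; // index attribute
    else
      llvm::errs() << "dynamic: " << size.get<Value>() << "\n"; // %m or %n
  }
  return success();
}
```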
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 9b6c2327879cf9..943a6df16ce01d 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -456,6 +456,14 @@ func.func @invalid_splat(%v : vector<8xf32>) {
 
 // -----
 
+func.func @invalid_splat(%v: f32, %m: index) {
+  // expected-error@+1 {{incorrect number of dynamic sizes, has 1, expected 2}}
+  %w = tensor.splat %v[%m] : tensor<?x8x?xf32>
+  return
+}
+
+// -----
+
 func.func @gather_empty_dims(
     %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
   // expected-error@+1 {{gather_dims must be non-empty}}
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 2282da38803af0..2b0a74acce0826 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -309,6 +309,16 @@ func.func @test_splat_op(%s : f32) {
   return
 }
 
+// CHECK-LABEL: func @test_splat_op
+// CHECK-SAME:  [[S:arg[0-9]+]]: f32
+// CHECK-SAME:  [[M:arg[0-9]+]]: index
+// CHECK-SAME:  [[N:arg[0-9]+]]: index
+func.func @test_splat_op_dynamic(%s: f32, %m: index, %n: index) {
+  // CHECK: tensor.splat %[[S]][%[[M]], %[[N]]] : tensor<?x8x?xf32>
+  %v = tensor.splat %s[%m, %n] : tensor<?x8x?xf32>
+  return
+}
+
 // -----
 
 // CHECK-LABEL: func.func @gather_scatter(

>From 8e6a26b2e3d2cf5bbfa71f6c8da462a69299c1d3 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <ru...@mathworks.com>
Date: Wed, 6 Dec 2023 10:43:49 -0500
Subject: [PATCH 3/4] Added unit tests for no-fold cases

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   |  6 +++-
 mlir/test/Dialect/Tensor/canonicalize.mlir | 40 ++++++++++++++++++++++
 2 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b5e15d8a6d4571..8fad57eea64aef 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1736,7 +1736,7 @@ class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
     auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
-    if (!splatOp)
+    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
       return failure();
 
     rewriter.replaceOpWithNewOp<tensor::SplatOp>(
@@ -3483,6 +3483,10 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
   if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
     return {};
 
+  // Do not fold if the splat is not statically shaped
+  if (!getType().hasStaticShape())
+    return {};
+
   // SplatElementsAttr::get treats single value for second arg as being a
   // splat.
   return SplatElementsAttr::get(getType(), {constOperand});
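[Editor's note: rationale for the two guards above, based on the attribute API rather than anything stated in the patch: `SplatElementsAttr` is a dense elements attribute, and elements attributes require a statically shaped type, so a dynamically shaped splat cannot fold to a constant; `FoldReshapeWithSplat` likewise bails because rebuilding the splat with the reshaped result type would need fresh dynamic-size operands. A standalone sketch with an invented helper name:]

```c++
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Materialize a constant splat only when every extent is known: a
// tensor<?xf32> has no fixed element count, so no ElementsAttr can
// represent its contents.
static Attribute tryConstantSplat(RankedTensorType type, Attribute element) {
  if (!type.hasStaticShape())
    return {};
  return SplatElementsAttr::get(type, {element});
}
```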
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 84c44a09aa3dd1..6b86341911f590 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1204,6 +1204,19 @@ func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> {
 
 // -----
 
+// CHECK-LABEL: @expand_shape_splat_dynamic_no_fold
+// CHECK-SAME:  %[[F:.+]]: f32
+// CHECK-SAME:  %[[M:.+]]: index
+func.func @expand_shape_splat_dynamic_no_fold(%arg: f32, %m: index) -> tensor<2x2x?xf32> {
+  // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]]
+  // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SPLAT]]
+  %c0 = tensor.splat %arg[%m] : tensor<2x?xf32>
+  %0 = tensor.expand_shape %c0 [[0], [1, 2]] : tensor<2x?xf32> into tensor<2x2x?xf32>
+  return %0 : tensor<2x2x?xf32>
+}
+
+// -----
+
 func.func @collapse_shape_splat(%arg : f32) -> tensor<2x4xf32> {
   %c0 = tensor.splat %arg : tensor<2x2x2xf32>
   %0 = tensor.collapse_shape %c0 [[0], [1, 2]]
@@ -1217,6 +1230,20 @@ func.func @collapse_shape_splat(%arg : f32) -> tensor<2x4xf32> {
 // CHECK: return %[[CST]]
 
 // -----
+
+// CHECK-LABEL: @collapse_shape_splat_dynamic_no_fold
+// CHECK-SAME:  %[[F:.+]]: f32
+// CHECK-SAME:  %[[M:.+]]: index
+func.func @collapse_shape_splat_dynamic_no_fold(%f: f32, %m: index) -> tensor<2x?xf32> {
+  // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]]
+  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[SPLAT]]
+  %c0 = tensor.splat %f[%m] : tensor<2x2x?xf32>
+  %0 = tensor.collapse_shape %c0 [[0], [1, 2]] : tensor<2x2x?xf32> into tensor<2x?xf32>
+  return %0 : tensor<2x?xf32>
+}
+
+// -----
+
 func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> {
   %c0 = arith.constant dense<42> : tensor<2x8xi16>
   %0 = tensor.expand_shape %c0 [[0], [1, 2]]
@@ -1627,6 +1654,19 @@ func.func @splat_fold() -> tensor<4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @splat_dynamic_no_fold
+// CHECK-SAME:  %[[M:.+]]: index
+func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {
+  // CHECK: %[[F:.+]] = arith.constant
+  %f = arith.constant 1.0 : f32
+
+  // CHECK: tensor.splat %[[F]][%[[M]]] : tensor<4x?xf32>
+  %t = tensor.splat %f[%m] : tensor<4x?xf32>
+  return %t : tensor<4x?xf32>
+}
+
+// -----
+
 // There was an issue in cast + insert_slice folding generating invalid ir.
 // https://github.com/llvm/llvm-project/issues/53099
 // CHECK-LABEL: func @insert_slice_cast

>From 1fdb7487fabe5bb8c775e343091ae9442100de7b Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <ru...@mathworks.com>
Date: Wed, 6 Dec 2023 12:09:29 -0500
Subject: [PATCH 4/4] Changed 'tensor.splat' example

---
 mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 60f188607e454c..251a53ed3b888d 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1745,7 +1745,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
     Example for a statically shaped tensor:
 
     ```mlir
-    %s = arith.constant 10.1 : f32
+    %s = arith.constant 1.0 : f32
     %t = tensor.splat %s : tensor<8x16xf32>
     ```