[llvm] [mlir] [clang] [clang-tools-extra] Support for dynamic dimensions in 'tensor.splat' (PR #74626)
https://github.com/rafaelubalmw updated https://github.com/llvm/llvm-project/pull/74626 >From 66287c8d3d23cfd3003baf82160013514b4bedb5 Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Tue, 5 Dec 2023 22:54:16 -0500 Subject: [PATCH 1/4] Progress in 'tensor.splat' extensions --- .../mlir/Dialect/Tensor/IR/TensorOps.td | 49 +-- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 43 2 files changed, 77 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index f50e3464867be5..60f188607e454c 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1728,6 +1728,7 @@ def Tensor_ScatterOp : Tensor_Op<"scatter", [ def Tensor_SplatOp : Tensor_Op<"splat", [ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, +DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, Pure, TypesMatchWith<"operand type matches element type of result", "aggregate", "input", @@ -1736,38 +1737,56 @@ def Tensor_SplatOp : Tensor_Op<"splat", [ let summary = "tensor splat or broadcast operation"; let description = [{ Broadcast the operand to all elements of the result tensor. The operand is -required to be of integer/index/float type, and the result tensor must be -statically shaped. +required to be of integer/index/float type. -Example: +An additional argument of type `index` must be provided for each dynamic +dimension present in the result type. + +Example for a statically shaped tensor: ```mlir %s = arith.constant 10.1 : f32 %t = tensor.splat %s : tensor<8x16xf32> ``` -TODO: This operation is easy to extend to broadcast to dynamically shaped - tensors: +Example for a tensor containing dynamic dimensions: ```mlir -// Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding -// to the sizes of the two dynamic dimensions. 
-%m = "foo"() : () -> (index) -%n = "bar"() : () -> (index) -%t = tensor.splat %s [%m, %n] : tensor<?x?xf32> +// Broadcasts %s to a 3D dynamically shaped tensor, with %m and %n binding +// to dimensions 0 and 2 of the resulting tensor, respectively. +%m = arith.constant 10 : index +%n = arith.constant 30 : index +%t = tensor.splat %s[%m, %n] : tensor<?x20x?xf32> ``` }]; let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat], - "integer/index/float type">:$input); - let results = (outs AnyStaticShapeTensor:$aggregate); + "integer/index/float type">:$input, + Variadic<Index>:$dynamicSizes); + let results = (outs AnyRankedTensor:$aggregate); let builders = [ -OpBuilder<(ins "Value":$element, "Type":$aggregateType), -[{ build($_builder, $_state, aggregateType, element); }]>]; - let assemblyFormat = "$input attr-dict `:` type($aggregate)"; +// Build with an explicit result type and a list of values corresponding +// to the dynamic sizes present in the result type. +OpBuilder<(ins "Value":$element, + "Type":$aggregateType, + CArg<"ValueRange", "{}">:$dynamicSizes)>, + +// Build with a result tensor shape and a list of values corresponding to +// the elements in the result tensor shape set to ShapedType::kDynamic. +OpBuilder<(ins "Value":$element, + "ArrayRef<int64_t>":$staticShape, + CArg<"ValueRange", "{}">:$dynamicSizes)>, + +// Build with mixed static/dynamic sizes, where an attribute represents +// a static dimension and a value represents a dynamic dimension. +OpBuilder<(ins "Value":$element, "ArrayRef<OpFoldResult>":$sizes)> + ]; + + let assemblyFormat = "$input (`[` $dynamicSizes^ `]`)? 
attr-dict `:` type($aggregate)"; let hasFolder = 1; + let hasVerifier = 1; } //===--===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index f15695383d34ab..b5e15d8a6d4571 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3430,11 +3430,54 @@ LogicalResult ScatterOp::verify() { // SplatOp //===--===// +void SplatOp::build(OpBuilder &builder, OperationState &result, Value element, +Type aggregateType, ValueRange dynamicSizes) { + build(builder, result, aggregateType, element, dynamicSizes); +} + +void SplatOp::build(OpBuilder &builder, OperationState &result, Value element, +ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) { + auto aggregateType = RankedTensorType::get(staticShape, element.getType()); + build(builder, result, aggregateType, element, dynamicSizes); +} + +void SplatOp::build(OpBuilder &builder, OperationState &result, Va
[clang-tools-extra] [mlir] [llvm] [clang] Support for dynamic dimensions in 'tensor.splat' (PR #74626)
https://github.com/rafaelubalmw updated https://github.com/llvm/llvm-project/pull/74626 >From 66287c8d3d23cfd3003baf82160013514b4bedb5 Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Tue, 5 Dec 2023 22:54:16 -0500 Subject: [PATCH 1/5] Progress in 'tensor.splat' extensions --- .../mlir/Dialect/Tensor/IR/TensorOps.td | 49 +-- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 43 2 files changed, 77 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index f50e3464867be5..60f188607e454c 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1728,6 +1728,7 @@ def Tensor_ScatterOp : Tensor_Op<"scatter", [ def Tensor_SplatOp : Tensor_Op<"splat", [ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, +DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, Pure, TypesMatchWith<"operand type matches element type of result", "aggregate", "input", @@ -1736,38 +1737,56 @@ def Tensor_SplatOp : Tensor_Op<"splat", [ let summary = "tensor splat or broadcast operation"; let description = [{ Broadcast the operand to all elements of the result tensor. The operand is -required to be of integer/index/float type, and the result tensor must be -statically shaped. +required to be of integer/index/float type. -Example: +An additional argument of type `index` must be provided for each dynamic +dimension present in the result type. + +Example for a statically shaped tensor: ```mlir %s = arith.constant 10.1 : f32 %t = tensor.splat %s : tensor<8x16xf32> ``` -TODO: This operation is easy to extend to broadcast to dynamically shaped - tensors: +Example for a tensor containing dynamic dimensions: ```mlir -// Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding -// to the sizes of the two dynamic dimensions. 
-%m = "foo"() : () -> (index) -%n = "bar"() : () -> (index) -%t = tensor.splat %s [%m, %n] : tensor<?x?xf32> +// Broadcasts %s to a 3D dynamically shaped tensor, with %m and %n binding +// to dimensions 0 and 2 of the resulting tensor, respectively. +%m = arith.constant 10 : index +%n = arith.constant 30 : index +%t = tensor.splat %s[%m, %n] : tensor<?x20x?xf32> ``` }]; let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat], - "integer/index/float type">:$input); - let results = (outs AnyStaticShapeTensor:$aggregate); + "integer/index/float type">:$input, + Variadic<Index>:$dynamicSizes); + let results = (outs AnyRankedTensor:$aggregate); let builders = [ -OpBuilder<(ins "Value":$element, "Type":$aggregateType), -[{ build($_builder, $_state, aggregateType, element); }]>]; - let assemblyFormat = "$input attr-dict `:` type($aggregate)"; +// Build with an explicit result type and a list of values corresponding +// to the dynamic sizes present in the result type. +OpBuilder<(ins "Value":$element, + "Type":$aggregateType, + CArg<"ValueRange", "{}">:$dynamicSizes)>, + +// Build with a result tensor shape and a list of values corresponding to +// the elements in the result tensor shape set to ShapedType::kDynamic. +OpBuilder<(ins "Value":$element, + "ArrayRef<int64_t>":$staticShape, + CArg<"ValueRange", "{}">:$dynamicSizes)>, + +// Build with mixed static/dynamic sizes, where an attribute represents +// a static dimension and a value represents a dynamic dimension. +OpBuilder<(ins "Value":$element, "ArrayRef<OpFoldResult>":$sizes)> + ]; + + let assemblyFormat = "$input (`[` $dynamicSizes^ `]`)? 
attr-dict `:` type($aggregate)"; let hasFolder = 1; + let hasVerifier = 1; } //===--===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index f15695383d34ab..b5e15d8a6d4571 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3430,11 +3430,54 @@ LogicalResult ScatterOp::verify() { // SplatOp //===--===// +void SplatOp::build(OpBuilder &builder, OperationState &result, Value element, +Type aggregateType, ValueRange dynamicSizes) { + build(builder, result, aggregateType, element, dynamicSizes); +} + +void SplatOp::build(OpBuilder &builder, OperationState &result, Value element, +ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) { + auto aggregateType = RankedTensorType::get(staticShape, element.getType()); + build(builder, result, aggregateType, element, dynamicSizes); +} + +void SplatOp::build(OpBuilder &builder, OperationState &result, Va