Issue |
133964
|
Summary |
[mlir][Bufferization] Bufferization does not seem to handle well degenerate slices.
|
Labels |
mlir,
mlir:bufferization
|
Assignees |
|
Reporter |
MaheshRavishankar
|
I have long had this issue with bufferization where it somehow always needs a `tensor.extract_slice` to "do the right thing" and gets thrown off if the `extract_slice` isn't there.
For example, take these two examples
```
func.func @test(%14: index, %0 : memref<8x16xf16>, %1 : memref<8xi32>, %2 : memref<?x16xf16>) {
%16 = bufferization.to_tensor %0 restrict : memref<8x16xf16> to tensor<8x16xf16>
%17 = bufferization.to_tensor %1 restrict : memref<8xi32> to tensor<8xi32>
%18 = bufferization.to_tensor %2 restrict : memref<?x16xf16> to tensor<?x16xf16>
%19 = scf.forall (%arg0) in (2) shared_outs(%arg1 = %18) -> (tensor<?x16xf16>) {
%20 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg0)
%extracted_slice = tensor.extract_slice %arg1[0, %20] [%14, 8] [1, 1] : tensor<?x16xf16> to tensor<?x8xf16>
%21 = scf.forall (%arg2, %arg3) in (8, 1) shared_outs(%arg4 = %extracted_slice) -> (tensor<?x8xf16>) {
%extracted_slice_0 = tensor.extract_slice %16[%arg2, %20] [1, 8] [1, 1] : tensor<8x16xf16> to tensor<1x8xf16>
%extracted_slice_1 = tensor.extract_slice %17[%arg2] [1] [1] : tensor<8xi32> to tensor<1xi32>
%22 = iree_linalg_ext.scatter {lowering_config = #iree_gpu.lowering_config<{thread = [1, 8], workgroup = [8, 8]}>} dimension_map = [0] unique_indices(true) ins(%extracted_slice_0, %extracted_slice_1 : tensor<1x8xf16>, tensor<1xi32>) outs(%arg4 : tensor<?x8xf16>)\
{
^bb0(%arg5: f16, %arg6: f16):
iree_linalg_ext.yield %arg5 : f16
} -> tensor<?x8xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %22 into %arg4[0, 0] [%14, 8] [1, 1] : tensor<?x8xf16> into tensor<?x8xf16>
}
} {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
scf.forall.in_parallel {
tensor.parallel_insert_slice %21 into %arg1[0, %20] [%14, 8] [1, 1] : tensor<?x8xf16> into tensor<?x16xf16>
}
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
bufferization.materialize_in_destination
%19 in restrict writable %2 : (tensor<?x16xf16>, memref<?x16xf16>) -> ()
return
}
// -----
func.func @test(%14: index, %0 : memref<8x16xf16>, %1 : memref<8xi32>, %2 : memref<?x16xf16>) {
%16 = bufferization.to_tensor %0 restrict : memref<8x16xf16> to tensor<8x16xf16>
%17 = bufferization.to_tensor %1 restrict : memref<8xi32> to tensor<8xi32>
%18 = bufferization.to_tensor %2 restrict : memref<?x16xf16> to tensor<?x16xf16>
%19 = scf.forall (%arg0) in (2) shared_outs(%arg1 = %18) -> (tensor<?x16xf16>) {
%20 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg0)
%extracted_slice = tensor.extract_slice %arg1[0, %20] [%14, 8] [1, 1] : tensor<?x16xf16> to tensor<?x8xf16>
%21 = scf.forall (%arg2, %arg3) in (8, 1) shared_outs(%arg4 = %extracted_slice) -> (tensor<?x8xf16>) {
%extracted_slice_0 = tensor.extract_slice %16[%arg2, %20] [1, 8] [1, 1] : tensor<8x16xf16> to tensor<1x8xf16>
%extracted_slice_1 = tensor.extract_slice %17[%arg2] [1] [1] : tensor<8xi32> to tensor<1xi32>
%extracted_slice_2 = tensor.extract_slice %arg4[0, 0] [%14, 8] [1, 1] : tensor<?x8xf16> to tensor<?x8xf16>
%22 = iree_linalg_ext.scatter {lowering_config = #iree_gpu.lowering_config<{thread = [1, 8], workgroup = [8, 8]}>} dimension_map = [0] unique_indices(true) ins(%extracted_slice_0, %extracted_slice_1 : tensor<1x8xf16>, tensor<1xi32>) outs(%extracted_slice_2 : ten\
sor<?x8xf16>) {
^bb0(%arg5: f16, %arg6: f16):
iree_linalg_ext.yield %arg5 : f16
} -> tensor<?x8xf16>
scf.forall.in_parallel {
tensor.parallel_insert_slice %22 into %arg4[0, 0] [%14, 8] [1, 1] : tensor<?x8xf16> into tensor<?x8xf16>
}
} {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
scf.forall.in_parallel {
tensor.parallel_insert_slice %21 into %arg1[0, %20] [%14, 8] [1, 1] : tensor<?x8xf16> into tensor<?x16xf16>
}
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
bufferization.materialize_in_destination
%19 in restrict writable %2 : (tensor<?x16xf16>, memref<?x16xf16>) -> ()
return
```
The only difference here is the degenerate `extract_slice` that I manually added
```
%extracted_slice_2 = tensor.extract_slice %arg4[0, 0] [%14, 8] [1, 1] : tensor<?x8xf16> to tensor<?x8xf16>
```
This is effectively a no-op/foldable extract_slice. But the two produce different outputs
```
module {
func.func @test(%arg0: index, %arg1: memref<8x16xf16>, %arg2: memref<8xi32>, %arg3: memref<?x16xf16>) {
%c0 = arith.constant 0 : index
%dim = memref.dim %arg3, %c0 : memref<?x16xf16>
%alloc = memref.alloc(%dim) : memref<?x16xf16>
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg3 : memref<?x16xf16>) outs(%alloc : memref<?x16xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
}
scf.forall (%arg4) in (2) {
%0 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg4)
%subview = memref.subview %alloc[0, %0] [%arg0, 8] [1, 1] : memref<?x16xf16> to memref<?x8xf16, strided<[16, 1], offset: ?>>
scf.forall (%arg5, %arg6) in (8, 1) {
%subview_0 = memref.subview %arg1[%arg5, %0] [1, 8] [1, 1] : memref<8x16xf16> to memref<1x8xf16, strided<[16, 1], offset: ?>>
%subview_1 = memref.subview %arg2[%arg5] [1] [1] : memref<8xi32> to memref<1xi32, strided<[1], offset: ?>>
%alloc_2 = memref.alloc(%arg0) : memref<?x8xf16>
iree_linalg_ext.scatter {lowering_config = #iree_gpu.lowering_config<{thread = [1, 8], workgroup = [8, 8]}>} dimension_map = [0] unique_indices(true) ins(%subview_0, %subview_1 : memref<1x8xf16, strided<[16, 1], offset: ?>>, memref<1xi32, strided<[1], offset: ?>>) outs(%alloc_2 : memref<?x8xf16>) {
^bb0(%arg7: f16, %arg8: f16):
iree_linalg_ext.yield %arg7 : f16
}
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%alloc_2 : memref<?x8xf16>) outs(%subview : memref<?x8xf16, strided<[16, 1], offset: ?>>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
}
} {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%subview : memref<?x8xf16, strided<[16, 1], offset: ?>>) outs(%subview : memref<?x8xf16, strided<[16, 1], offset: ?>>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
}
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%alloc : memref<?x16xf16>) outs(%arg3 : memref<?x16xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
}
return
}
}
// -----
module {
func.func @test(%arg0: index, %arg1: memref<8x16xf16>, %arg2: memref<8xi32>, %arg3: memref<?x16xf16>) {
%c0 = arith.constant 0 : index
%dim = memref.dim %arg3, %c0 : memref<?x16xf16>
%alloc = memref.alloc(%dim) : memref<?x16xf16>
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg3 : memref<?x16xf16>) outs(%alloc : memref<?x16xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
}
scf.forall (%arg4) in (2) {
%0 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg4)
%subview = memref.subview %alloc[0, %0] [%arg0, 8] [1, 1] : memref<?x16xf16> to memref<?x8xf16, strided<[16, 1], offset: ?>>
scf.forall (%arg5, %arg6) in (8, 1) {
%subview_0 = memref.subview %arg1[%arg5, %0] [1, 8] [1, 1] : memref<8x16xf16> to memref<1x8xf16, strided<[16, 1], offset: ?>>
%subview_1 = memref.subview %arg2[%arg5] [1] [1] : memref<8xi32> to memref<1xi32, strided<[1], offset: ?>>
iree_linalg_ext.scatter {lowering_config = #iree_gpu.lowering_config<{thread = [1, 8], workgroup = [8, 8]}>} dimension_map = [0] unique_indices(true) ins(%subview_0, %subview_1 : memref<1x8xf16, strided<[16, 1], offset: ?>>, memref<1xi32, strided<[1], offset: ?>>) outs(%subview : memref<?x8xf16, strided<\[16, 1], offset: ?>>) {
^bb0(%arg7: f16, %arg8: f16):
iree_linalg_ext.yield %arg7 : f16
}
} {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%subview : memref<?x8xf16, strided<[16, 1], offset: ?>>) outs(%subview : memref<?x8xf16, strided<[16, 1], offset: ?>>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
}
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%alloc : memref<?x16xf16>) outs(%arg3 : memref<?x16xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
}
return
}
}
```
(Ignore the outer `alloc` that I don't know how to avoid and that is just an artifact of the test). The inner `alloc` created in the first example is completely unnecessary. Bufferization has to be able to deal with the slice not being there.
@matthias-springer can you suggest how I can go about fixing this issue?
_______________________________________________
llvm-bugs mailing list
llvm-bugs@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs