Issue 133964
Summary [mlir][Bufferization] Bufferization does not seem to handle well degenerate slices.
Labels mlir, mlir:bufferization
Assignees
Reporter MaheshRavishankar
    I have long had this issue with bufferization, where it somehow always needs a `tensor.extract_slice` to "do the right thing" and gets thrown off if the `extract_slice` isn't there.

For example, take these two examples

```
func.func @test(%14: index, %0 : memref<8x16xf16>, %1 : memref<8xi32>, %2 : memref<?x16xf16>) {
  %16 = bufferization.to_tensor %0 restrict : memref<8x16xf16> to tensor<8x16xf16>
  %17 = bufferization.to_tensor %1 restrict : memref<8xi32> to tensor<8xi32>
  %18 = bufferization.to_tensor %2 restrict : memref<?x16xf16> to tensor<?x16xf16>
 %19 = scf.forall (%arg0) in (2) shared_outs(%arg1 = %18) -> (tensor<?x16xf16>) {
    %20 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg0)
    %extracted_slice = tensor.extract_slice %arg1[0, %20] [%14, 8] [1, 1] : tensor<?x16xf16> to tensor<?x8xf16>
    %21 = scf.forall (%arg2, %arg3) in (8, 1) shared_outs(%arg4 = %extracted_slice) -> (tensor<?x8xf16>) {
      %extracted_slice_0 = tensor.extract_slice %16[%arg2, %20] [1, 8] [1, 1] : tensor<8x16xf16> to tensor<1x8xf16>
      %extracted_slice_1 = tensor.extract_slice %17[%arg2] [1] [1] : tensor<8xi32> to tensor<1xi32>
 %22 = iree_linalg_ext.scatter {lowering_config = #iree_gpu.lowering_config<{thread = [1, 8], workgroup = [8, 8]}>} dimension_map = [0] unique_indices(true) ins(%extracted_slice_0, %extracted_slice_1 : tensor<1x8xf16>, tensor<1xi32>) outs(%arg4 : tensor<?x8xf16>)\
 {
      ^bb0(%arg5: f16, %arg6: f16):
 iree_linalg_ext.yield %arg5 : f16
      } -> tensor<?x8xf16>
 scf.forall.in_parallel {
        tensor.parallel_insert_slice %22 into %arg4[0, 0] [%14, 8] [1, 1] : tensor<?x8xf16> into tensor<?x8xf16>
      }
 } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
 scf.forall.in_parallel {
      tensor.parallel_insert_slice %21 into %arg1[0, %20] [%14, 8] [1, 1] : tensor<?x8xf16> into tensor<?x16xf16>
 }
  } {mapping = [#iree_codegen.workgroup_mapping<x>]}
 bufferization.materialize_in_destination
    %19 in restrict writable %2 : (tensor<?x16xf16>, memref<?x16xf16>) -> ()
  return
}

// -----

func.func @test(%14: index, %0 : memref<8x16xf16>, %1 : memref<8xi32>, %2 : memref<?x16xf16>) {
  %16 = bufferization.to_tensor %0 restrict : memref<8x16xf16> to tensor<8x16xf16>
  %17 = bufferization.to_tensor %1 restrict : memref<8xi32> to tensor<8xi32>
  %18 = bufferization.to_tensor %2 restrict : memref<?x16xf16> to tensor<?x16xf16>
 %19 = scf.forall (%arg0) in (2) shared_outs(%arg1 = %18) -> (tensor<?x16xf16>) {
    %20 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg0)
    %extracted_slice = tensor.extract_slice %arg1[0, %20] [%14, 8] [1, 1] : tensor<?x16xf16> to tensor<?x8xf16>
    %21 = scf.forall (%arg2, %arg3) in (8, 1) shared_outs(%arg4 = %extracted_slice) -> (tensor<?x8xf16>) {
      %extracted_slice_0 = tensor.extract_slice %16[%arg2, %20] [1, 8] [1, 1] : tensor<8x16xf16> to tensor<1x8xf16>
      %extracted_slice_1 = tensor.extract_slice %17[%arg2] [1] [1] : tensor<8xi32> to tensor<1xi32>
 %extracted_slice_2 = tensor.extract_slice %arg4[0, 0] [%14, 8] [1, 1] : tensor<?x8xf16> to tensor<?x8xf16>
      %22 = iree_linalg_ext.scatter {lowering_config = #iree_gpu.lowering_config<{thread = [1, 8], workgroup = [8, 8]}>} dimension_map = [0] unique_indices(true) ins(%extracted_slice_0, %extracted_slice_1 : tensor<1x8xf16>, tensor<1xi32>) outs(%extracted_slice_2 : ten\
sor<?x8xf16>) {
      ^bb0(%arg5: f16, %arg6: f16):
 iree_linalg_ext.yield %arg5 : f16
      } -> tensor<?x8xf16>
 scf.forall.in_parallel {
        tensor.parallel_insert_slice %22 into %arg4[0, 0] [%14, 8] [1, 1] : tensor<?x8xf16> into tensor<?x8xf16>
      }
 } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
 scf.forall.in_parallel {
      tensor.parallel_insert_slice %21 into %arg1[0, %20] [%14, 8] [1, 1] : tensor<?x8xf16> into tensor<?x16xf16>
 }
  } {mapping = [#iree_codegen.workgroup_mapping<x>]}
 bufferization.materialize_in_destination
    %19 in restrict writable %2 : (tensor<?x16xf16>, memref<?x16xf16>) -> ()
  return
```

The only difference here is the degenerate `extract_slice` that I manually added

```
%extracted_slice_2 = tensor.extract_slice %arg4[0, 0] [%14, 8] [1, 1] : tensor<?x8xf16> to tensor<?x8xf16>
```

This is effectively a no-op/foldable `extract_slice`, but the two examples produce different outputs:

```
module {
  func.func @test(%arg0: index, %arg1: memref<8x16xf16>, %arg2: memref<8xi32>, %arg3: memref<?x16xf16>) {
    %c0 = arith.constant 0 : index
    %dim = memref.dim %arg3, %c0 : memref<?x16xf16>
    %alloc = memref.alloc(%dim) : memref<?x16xf16>
 linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg3 : memref<?x16xf16>) outs(%alloc : memref<?x16xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
    }
 scf.forall (%arg4) in (2) {
      %0 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg4)
      %subview = memref.subview %alloc[0, %0] [%arg0, 8] [1, 1] : memref<?x16xf16> to memref<?x8xf16, strided<[16, 1], offset: ?>>
 scf.forall (%arg5, %arg6) in (8, 1) {
        %subview_0 = memref.subview %arg1[%arg5, %0] [1, 8] [1, 1] : memref<8x16xf16> to memref<1x8xf16, strided<[16, 1], offset: ?>>
        %subview_1 = memref.subview %arg2[%arg5] [1] [1] : memref<8xi32> to memref<1xi32, strided<[1], offset: ?>>
        %alloc_2 = memref.alloc(%arg0) : memref<?x8xf16>
 iree_linalg_ext.scatter {lowering_config = #iree_gpu.lowering_config<{thread = [1, 8], workgroup = [8, 8]}>} dimension_map = [0] unique_indices(true) ins(%subview_0, %subview_1 : memref<1x8xf16, strided<[16, 1], offset: ?>>, memref<1xi32, strided<[1], offset: ?>>) outs(%alloc_2 : memref<?x8xf16>) {
 ^bb0(%arg7: f16, %arg8: f16):
          iree_linalg_ext.yield %arg7 : f16
        }
        linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%alloc_2 : memref<?x8xf16>) outs(%subview : memref<?x8xf16, strided<[16, 1], offset: ?>>) {
        ^bb0(%in: f16, %out: f16):
          linalg.yield %in : f16
        }
      } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
      linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%subview : memref<?x8xf16, strided<[16, 1], offset: ?>>) outs(%subview : memref<?x8xf16, strided<[16, 1], offset: ?>>) {
      ^bb0(%in: f16, %out: f16):
 linalg.yield %in : f16
      }
    } {mapping = [#iree_codegen.workgroup_mapping<x>]}
    linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%alloc : memref<?x16xf16>) outs(%arg3 : memref<?x16xf16>) {
    ^bb0(%in: f16, %out: f16):
 linalg.yield %in : f16
    }
    return
  }
}
// -----
module {
 func.func @test(%arg0: index, %arg1: memref<8x16xf16>, %arg2: memref<8xi32>, %arg3: memref<?x16xf16>) {
    %c0 = arith.constant 0 : index
    %dim = memref.dim %arg3, %c0 : memref<?x16xf16>
    %alloc = memref.alloc(%dim) : memref<?x16xf16>
    linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg3 : memref<?x16xf16>) outs(%alloc : memref<?x16xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
    }
 scf.forall (%arg4) in (2) {
      %0 = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg4)
      %subview = memref.subview %alloc[0, %0] [%arg0, 8] [1, 1] : memref<?x16xf16> to memref<?x8xf16, strided<[16, 1], offset: ?>>
 scf.forall (%arg5, %arg6) in (8, 1) {
        %subview_0 = memref.subview %arg1[%arg5, %0] [1, 8] [1, 1] : memref<8x16xf16> to memref<1x8xf16, strided<[16, 1], offset: ?>>
        %subview_1 = memref.subview %arg2[%arg5] [1] [1] : memref<8xi32> to memref<1xi32, strided<[1], offset: ?>>
        iree_linalg_ext.scatter {lowering_config = #iree_gpu.lowering_config<{thread = [1, 8], workgroup = [8, 8]}>} dimension_map = [0] unique_indices(true) ins(%subview_0, %subview_1 : memref<1x8xf16, strided<[16, 1], offset: ?>>, memref<1xi32, strided<[1], offset: ?>>) outs(%subview : memref<?x8xf16, strided<\[16, 1], offset: ?>>) {
        ^bb0(%arg7: f16, %arg8: f16):
          iree_linalg_ext.yield %arg7 : f16
        }
      } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
      linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%subview : memref<?x8xf16, strided<[16, 1], offset: ?>>) outs(%subview : memref<?x8xf16, strided<[16, 1], offset: ?>>) {
      ^bb0(%in: f16, %out: f16):
        linalg.yield %in : f16
      }
    } {mapping = [#iree_codegen.workgroup_mapping<x>]}
 linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%alloc : memref<?x16xf16>) outs(%arg3 : memref<?x16xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
    }
 return
  }
}
```

(Ignore the outer `memref.alloc`, which I don't know how to avoid and is just an artifact of the test.) The inner `memref.alloc` created in the first example is completely unnecessary. Bufferization has to be able to deal with the slice not being there.

@matthias-springer  can you suggest how I can go about fixing this issue? 
_______________________________________________
llvm-bugs mailing list
llvm-bugs@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs

Reply via email to