| Issue | 123375 |
|---|---|
| Summary | [mlir] BlockEquivalenceData is wrong? |
| Labels | mlir |
| Assignees | jpienaar |
| Reporter | makslevental |

```mlir
tt.func @condBranch(%arg0: !tt.ptr<f32>, %arg1: i1) -> tensor<1024xf32> {
  %c1024_i32 = arith.constant 1024 : i32
  %0 = tt.get_program_id x : i32
  %1 = arith.muli %0, %c1024_i32 : i32
  %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
  %3 = tt.splat %1 : i32 -> tensor<1024xi32>
  %4 = arith.addi %3, %2 : tensor<1024xi32>
  %5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
  %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
  cf.cond_br %arg1, ^bb1(%5 : tensor<1024x!tt.ptr<f32>>), ^bb2(%6 : tensor<1024x!tt.ptr<f32>>)
^bb1(%7: tensor<1024x!tt.ptr<f32>>): // pred: ^bb0
  %8 = tt.load %7 : tensor<1024x!tt.ptr<f32>>
  tt.return %8 : tensor<1024xf32>
^bb2(%9: tensor<1024x!tt.ptr<f32>>): // pred: ^bb0
  %10 = tt.load %9 : tensor<1024x!tt.ptr<f32>>
  tt.return %10 : tensor<1024xf32>
}
```
Running `mlir::simplifyRegions` on it gives:
```mlir
tt.func @condBranch(%arg0: !tt.ptr<f32>, %arg1: i1) -> tensor<1024xf32> {
  %c0_i64 = arith.constant 0 : i64
  %0 = builtin.unrealized_conversion_cast %arg0, %c0_i64 : !tt.ptr<f32>, i64 to !tt.ptr<f32>
  %c1024_i32 = arith.constant 1024 : i32
  %1 = tt.get_program_id x : i32
  %2 = arith.muli %1, %c1024_i32 : i32
  %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
  %4 = tt.splat %2 : i32 -> tensor<1024xi32>
  %5 = arith.addi %4, %3 : tensor<1024xi32>
  %6 = tt.splat %0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
  %7 = tt.addptr %6, %5 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
  cf.cond_br %arg1, ^bb1(%6 : tensor<1024x!tt.ptr<f32>>), ^bb1(%7 : tensor<1024x!tt.ptr<f32>>)
^bb1(%8: tensor<1024x!tt.ptr<f32>>): // 2 preds: ^bb0, ^bb0
  %9 = tt.load %8 : tensor<1024x!tt.ptr<f32>>
  tt.return %9 : tensor<1024xf32>
}
```
The blocks get merged [because](https://github.com/llvm/llvm-project/blob/1e5f32e81f96af45551dafb369279c6d55ac9b97/mlir/lib/Transforms/Utils/RegionUtils.cpp#L491-L495) of how `BlockEquivalenceData` defines equivalent blocks:
> /// This class contains the information for comparing the equivalencies of two
> /// blocks. Blocks are considered equivalent if they contain the same operations
> /// in the same order. The only allowed divergence is for operands that come
> /// from sources outside of the parent block, i.e. the uses of values produced
> /// within the block must be equivalent.

I don't understand how this is a legal merge.
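To make the equivalence rule quoted above concrete, here is a toy Python model of it. It is a sketch under stated assumptions, not MLIR's actual C++ implementation: `Op`, `Block`, `local_order`, and `equivalent` are hypothetical names, an op's `name` string stands in for its op name, attributes, and result types, and operand types are not modeled. Values defined inside a block (block arguments first, then op results) are numbered by position. Two blocks count as equivalent when their argument types match, their ops match in order, and every block-local operand refers to the same position in both blocks. Operands defined outside the block may differ; the sketch reports those positions, since a merge would have to turn them into new block arguments.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple


@dataclass
class Op:
    # Toy-model assumption: `name` stands in for everything compared besides
    # operands (op name, attributes, result types).
    name: str
    operands: Tuple[str, ...]
    results: Tuple[str, ...] = ()


@dataclass
class Block:
    args: Tuple[str, ...]       # SSA names of the block arguments
    arg_types: Tuple[str, ...]  # argument types, which must match for a merge
    ops: List[Op] = field(default_factory=list)


def local_order(block: Block) -> Dict[str, int]:
    """Number every value defined in the block: arguments first, then op results."""
    order = {name: i for i, name in enumerate(block.args)}
    next_index = len(block.args)
    for op in block.ops:
        for result in op.results:
            order[result] = next_index
            next_index += 1
    return order


def equivalent(lhs: Block, rhs: Block) -> Optional[List[Tuple[int, int]]]:
    """Return None if the blocks are not equivalent. Otherwise return the
    (op index, operand index) pairs whose operands come from outside the
    block and differ; merging would turn those into new block arguments."""
    if lhs.arg_types != rhs.arg_types or len(lhs.ops) != len(rhs.ops):
        return None
    lhs_order, rhs_order = local_order(lhs), local_order(rhs)
    divergent: List[Tuple[int, int]] = []
    for op_index, (lop, rop) in enumerate(zip(lhs.ops, rhs.ops)):
        if (lop.name != rop.name
                or len(lop.operands) != len(rop.operands)
                or len(lop.results) != len(rop.results)):
            return None
        for operand_index, (lv, rv) in enumerate(zip(lop.operands, rop.operands)):
            lhs_local, rhs_local = lv in lhs_order, rv in rhs_order
            if lhs_local != rhs_local:
                return None  # one use is block-local, the other is external
            if lhs_local:
                # Block-local values must sit at the same logical position.
                if lhs_order[lv] != rhs_order[rv]:
                    return None
            elif lv != rv:
                divergent.append((op_index, operand_index))
    return divergent


if __name__ == "__main__":
    ptr = "tensor<1024x!tt.ptr<f32>>"
    # ^bb1 and ^bb2 from the input IR, transcribed into the toy model.
    bb1 = Block(("%7",), (ptr,), [Op("tt.load", ("%7",), ("%8",)),
                                  Op("tt.return", ("%8",))])
    bb2 = Block(("%9",), (ptr,), [Op("tt.load", ("%9",), ("%10",)),
                                  Op("tt.return", ("%10",))])
    print(equivalent(bb1, bb2))  # prints []
```

In this model `^bb1` and `^bb2` come out equivalent with no divergent operands: each block loads from its own argument 0 and returns the result of its first op, so the only difference between them is the SSA names.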
_______________________________________________
llvm-bugs mailing list
llvm-bugs@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs