Issue 123375
Summary [mlir] BlockEquivalenceData is wrong?
Labels mlir
Assignees jpienaar
Reporter makslevental
```mlir
 tt.func @condBranch(%arg0: !tt.ptr<f32>, %arg1: i1) -> tensor<1024xf32> {
 %c1024_i32 = arith.constant 1024 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c1024_i32 : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %3 = tt.splat %1 : i32 -> tensor<1024xi32>
    %4 = arith.addi %3, %2 : tensor<1024xi32>
    %5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
 cf.cond_br %arg1, ^bb1(%5 : tensor<1024x!tt.ptr<f32>>), ^bb2(%6 : tensor<1024x!tt.ptr<f32>>)
  ^bb1(%7: tensor<1024x!tt.ptr<f32>>):  // pred: ^bb0
    %8 = tt.load %7 : tensor<1024x!tt.ptr<f32>>
    tt.return %8 : tensor<1024xf32>
  ^bb2(%9: tensor<1024x!tt.ptr<f32>>):  // pred: ^bb0
 %10 = tt.load %9 : tensor<1024x!tt.ptr<f32>>
    tt.return %10 : tensor<1024xf32>
  }
```

Running `mlir::simplifyRegions` on this function gives:

```mlir
 tt.func @condBranch(%arg0: !tt.ptr<f32>, %arg1: i1) -> tensor<1024xf32> {
 %c0_i64 = arith.constant 0 : i64
    %0 = builtin.unrealized_conversion_cast %arg0, %c0_i64 : !tt.ptr<f32>, i64 to !tt.ptr<f32>
    %c1024_i32 = arith.constant 1024 : i32
    %1 = tt.get_program_id x : i32
    %2 = arith.muli %1, %c1024_i32 : i32
    %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %4 = tt.splat %2 : i32 -> tensor<1024xi32>
    %5 = arith.addi %4, %3 : tensor<1024xi32>
    %6 = tt.splat %0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %7 = tt.addptr %6, %5 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    cf.cond_br %arg1, ^bb1(%6 : tensor<1024x!tt.ptr<f32>>), ^bb1(%7 : tensor<1024x!tt.ptr<f32>>)
  ^bb1(%8: tensor<1024x!tt.ptr<f32>>):  // 2 preds: ^bb0, ^bb0
    %9 = tt.load %8 : tensor<1024x!tt.ptr<f32>>
    tt.return %9 : tensor<1024xf32>
 }
```

The two blocks are merged [because of this rule in `RegionUtils.cpp`](https://github.com/llvm/llvm-project/blob/1e5f32e81f96af45551dafb369279c6d55ac9b97/mlir/lib/Transforms/Utils/RegionUtils.cpp#L491-L495):

> /// This class contains the information for comparing the equivalencies of two
> /// blocks. Blocks are considered equivalent if they contain the same operations
> /// in the same order. The only allowed divergence is for operands that come
> /// from sources outside of the parent block, i.e. the uses of values produced
> /// within the block must be equivalent.

I don't understand how that is a legal merge/rewrite/change.


_______________________________________________
llvm-bugs mailing list
llvm-bugs@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs

Reply via email to