Issue |
95230
|
Summary |
Unexpected Behavior on Affine-Loop-Fusion
|
Labels |
new issue
|
Assignees |
|
Reporter |
sgjzfzzf
|
Hi, I'm developing a compiler for ONNX based on MLIR. I'm trying to optimize the code generation with the Affine passes, but I need some help: an error occurs when processing the LayerNormalization operator. Here is the code generated automatically by my compiler.
```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map2 = affine_map<(d0, d1, d2) -> (d2)>
module {
func.func @layer_normalization(%arg0: memref<1x128x768xf32>, %arg1: memref<1x128x768xf32>, %arg2: memref<512xi8>) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%view = memref.view %arg2[%c0][] : memref<512xi8> to memref<512xi8>
%cst = arith.constant dense<1.000000e+00> : tensor<768xf32>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<768xf32>
%c0_1 = arith.constant 0 : index
%cst_2 = arith.constant 0.000000e+00 : f32
%cst_3 = arith.constant 1.000000e+00 : f32
%view_4 = memref.view %view[%c0_1][] : memref<512xi8> to memref<1x128x1xf32>
%0 = bufferization.to_memref %cst : memref<768xf32>
%1 = bufferization.to_memref %cst_0 : memref<768xf32>
linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg1 : memref<1x128x768xf32>) outs(%view_4 : memref<1x128x1xf32>) {
^bb0(%in: f32, %out: f32):
%2 = arith.addf %in, %out : f32
linalg.yield %2 : f32
}
linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%view_4 : memref<1x128x1xf32>) outs(%view_4 : memref<1x128x1xf32>) {
^bb0(%in: f32, %out: f32):
%cst_5 = arith.constant 7.680000e+02 : f32
%2 = arith.divf %in, %cst_5 : f32
linalg.yield %2 : f32
}
linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1, %view_4 : memref<1x128x768xf32>, memref<1x128x1xf32>) outs(%arg0 : memref<1x128x768xf32>) {
^bb0(%in: f32, %in_5: f32, %out: f32):
%2 = arith.subf %in, %in_5 : f32
linalg.yield %2 : f32
}
linalg.fill ins(%cst_2 : f32) outs(%view_4 : memref<1x128x1xf32>)
linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : memref<1x128x768xf32>) outs(%view_4 : memref<1x128x1xf32>) {
^bb0(%in: f32, %out: f32):
%2 = arith.mulf %in, %in : f32
%3 = arith.addf %out, %2 : f32
linalg.yield %3 : f32
}
linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%view_4 : memref<1x128x1xf32>) outs(%view_4 : memref<1x128x1xf32>) {
^bb0(%in: f32, %out: f32):
%cst_5 = arith.constant 7.680000e+02 : f32
%2 = arith.divf %in, %cst_5 : f32
linalg.yield %2 : f32
}
linalg.generic {indexing_maps = [#map, #map1, #map2, #map2, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %view_4, %0, %1 : memref<1x128x768xf32>, memref<1x128x1xf32>, memref<768xf32>, memref<768xf32>) outs(%arg0 : memref<1x128x768xf32>) {
^bb0(%in: f32, %in_5: f32, %in_6: f32, %in_7: f32, %out: f32):
%cst_8 = arith.constant 9.99999996E-13 : f32
%2 = arith.addf %in_5, %cst_8 : f32
%3 = math.sqrt %2 : f32
%4 = arith.divf %cst_3, %3 : f32
%5 = arith.mulf %in, %4 : f32
%6 = arith.mulf %5, %in_6 : f32
%7 = arith.addf %6, %in_7 : f32
linalg.yield %7 : f32
}
return
}
}
```
Then, I use `mlir-opt-18 -convert-linalg-to-affine-loops <filename>` to lower it to the Affine dialect.
```mlir
module {
func.func @layer_normalization(%arg0: memref<1x128x768xf32>, %arg1: memref<1x128x768xf32>, %arg2: memref<512xi8>) attributes {llvm.emit_c_interface} {
%cst = arith.constant 9.99999996E-13 : f32
%cst_0 = arith.constant 7.680000e+02 : f32
%cst_1 = arith.constant 1.000000e+00 : f32
%cst_2 = arith.constant 0.000000e+00 : f32
%cst_3 = arith.constant dense<0.000000e+00> : tensor<768xf32>
%cst_4 = arith.constant dense<1.000000e+00> : tensor<768xf32>
%c0 = arith.constant 0 : index
%view = memref.view %arg2[%c0][] : memref<512xi8> to memref<512xi8>
%view_5 = memref.view %view[%c0][] : memref<512xi8> to memref<1x128x1xf32>
%0 = bufferization.to_memref %cst_4 : memref<768xf32>
%1 = bufferization.to_memref %cst_3 : memref<768xf32>
affine.for %arg3 = 0 to 1 {
affine.for %arg4 = 0 to 128 {
affine.for %arg5 = 0 to 768 {
%2 = affine.load %arg1[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
%3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
%4 = arith.addf %2, %3 : f32
affine.store %4, %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
}
}
}
affine.for %arg3 = 0 to 1 {
affine.for %arg4 = 0 to 128 {
affine.for %arg5 = 0 to 1 {
%2 = affine.load %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
%3 = arith.divf %2, %cst_0 : f32
affine.store %3, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
}
}
}
affine.for %arg3 = 0 to 1 {
affine.for %arg4 = 0 to 128 {
affine.for %arg5 = 0 to 768 {
%2 = affine.load %arg1[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
%3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
%4 = arith.subf %2, %3 : f32
affine.store %4, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
}
}
}
affine.for %arg3 = 0 to 1 {
affine.for %arg4 = 0 to 128 {
affine.for %arg5 = 0 to 1 {
affine.store %cst_2, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
}
}
}
affine.for %arg3 = 0 to 1 {
affine.for %arg4 = 0 to 128 {
affine.for %arg5 = 0 to 768 {
%2 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
%3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
%4 = arith.mulf %2, %2 : f32
%5 = arith.addf %3, %4 : f32
affine.store %5, %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
}
}
}
affine.for %arg3 = 0 to 1 {
affine.for %arg4 = 0 to 128 {
affine.for %arg5 = 0 to 1 {
%2 = affine.load %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
%3 = arith.divf %2, %cst_0 : f32
affine.store %3, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
}
}
}
affine.for %arg3 = 0 to 1 {
affine.for %arg4 = 0 to 128 {
affine.for %arg5 = 0 to 768 {
%2 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
%3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
%4 = affine.load %0[%arg5] : memref<768xf32>
%5 = affine.load %1[%arg5] : memref<768xf32>
%6 = arith.addf %3, %cst : f32
%7 = math.sqrt %6 : f32
%8 = arith.divf %cst_1, %7 : f32
%9 = arith.mulf %2, %8 : f32
%10 = arith.mulf %9, %4 : f32
%11 = arith.addf %10, %5 : f32
affine.store %11, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
}
}
}
return
}
}
```
After that, I decided to use the Affine-Loop-Fusion pass to optimize it with `mlir-opt-18 -convert-linalg-to-affine-loops -affine-loop-fusion=mode=greedy <filename>`.
```mlir
module {
func.func @layer_normalization(%arg0: memref<1x128x768xf32>, %arg1: memref<1x128x768xf32>, %arg2: memref<512xi8>) attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c0_0 = arith.constant 0 : index
%c0_1 = arith.constant 0 : index
%c0_2 = arith.constant 0 : index
%c0_3 = arith.constant 0 : index
%c0_4 = arith.constant 0 : index
%c0_5 = arith.constant 0 : index
%c0_6 = arith.constant 0 : index
%c0_7 = arith.constant 0 : index
%cst = arith.constant 9.99999996E-13 : f32
%cst_8 = arith.constant 7.680000e+02 : f32
%cst_9 = arith.constant 1.000000e+00 : f32
%cst_10 = arith.constant 0.000000e+00 : f32
%cst_11 = arith.constant dense<0.000000e+00> : tensor<768xf32>
%cst_12 = arith.constant dense<1.000000e+00> : tensor<768xf32>
%c0_13 = arith.constant 0 : index
%view = memref.view %arg2[%c0_13][] : memref<512xi8> to memref<512xi8>
%view_14 = memref.view %view[%c0_13][] : memref<512xi8> to memref<1x128x1xf32>
%0 = bufferization.to_memref %cst_12 : memref<768xf32>
%1 = bufferization.to_memref %cst_11 : memref<768xf32>
affine.for %arg3 = 0 to 1 {
affine.for %arg4 = 0 to 128 {
affine.for %arg5 = 0 to 768 {
%4 = affine.load %arg1[%c0, %arg4, %arg5] : memref<1x128x768xf32>
%5 = affine.load %view_14[%c0, %arg4, %c0_13] : memref<1x128x1xf32>
%6 = arith.addf %4, %5 : f32
affine.store %6, %view_14[%c0, %arg4, %c0_13] : memref<1x128x1xf32>
}
%2 = affine.load %view_14[%c0_1, %arg4, %c0_0] : memref<1x128x1xf32>
%3 = arith.divf %2, %cst_8 : f32
affine.store %3, %view_14[%c0_1, %arg4, %c0_0] : memref<1x128x1xf32>
affine.for %arg5 = 0 to 768 {
%4 = affine.load %arg1[%c0_2, %arg4, %arg5] : memref<1x128x768xf32>
%5 = affine.load %view_14[%c0_2, %arg4, %c0_13] : memref<1x128x1xf32>
%6 = arith.subf %4, %5 : f32
affine.store %6, %arg0[%c0_2, %arg4, %arg5] : memref<1x128x768xf32>
}
affine.store %cst_10, %view_14[%c0_4, %arg4, %c0_3] : memref<1x128x1xf32>
affine.for %arg5 = 0 to 768 {
%4 = affine.load %arg0[%c0_5, %arg4, %arg5] : memref<1x128x768xf32>
%5 = affine.load %view_14[%c0_5, %arg4, %c0_13] : memref<1x128x1xf32>
%6 = arith.mulf %4, %4 : f32
%7 = arith.addf %5, %6 : f32
affine.store %7, %view_14[%c0_5, %arg4, %c0_13] : memref<1x128x1xf32>
}
affine.for %arg5 = 0 to 768 {
%4 = affine.load %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
%5 = arith.divf %4, %cst_8 : f32
affine.store %5, %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
%6 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
%7 = affine.load %view_14[%arg3, %arg4, %c0_13] : memref<1x128x1xf32>
%8 = affine.load %0[%arg5] : memref<768xf32>
%9 = affine.load %1[%arg5] : memref<768xf32>
%10 = arith.addf %7, %cst : f32
%11 = math.sqrt %10 : f32
%12 = arith.divf %cst_9, %11 : f32
%13 = arith.mulf %6, %12 : f32
%14 = arith.mulf %13, %8 : f32
%15 = arith.addf %14, %9 : f32
affine.store %15, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
}
}
}
return
}
}
```
Please take a look at the last affine-for loop. The comparison follows:
```mlir
// ...
affine.for %arg3 = 0 to 1 {
affine.for %arg4 = 0 to 128 {
affine.for %arg5 = 0 to 1 {
%2 = affine.load %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
%3 = arith.divf %2, %cst_0 : f32
affine.store %3, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
}
}
}
affine.for %arg3 = 0 to 1 {
affine.for %arg4 = 0 to 128 {
affine.for %arg5 = 0 to 768 {
%2 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
%3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
%4 = affine.load %0[%arg5] : memref<768xf32>
%5 = affine.load %1[%arg5] : memref<768xf32>
%6 = arith.addf %3, %cst : f32
%7 = math.sqrt %6 : f32
%8 = arith.divf %cst_1, %7 : f32
%9 = arith.mulf %2, %8 : f32
%10 = arith.mulf %9, %4 : f32
%11 = arith.addf %10, %5 : f32
affine.store %11, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
}
}
}
affine.for %arg3 = 0 to 1 {
affine.for %arg4 = 0 to 128 {
// ...
affine.for %arg5 = 0 to 768 {
%4 = affine.load %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
%5 = arith.divf %4, %cst_8 : f32
affine.store %5, %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
%6 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
%7 = affine.load %view_14[%arg3, %arg4, %c0_13] : memref<1x128x1xf32>
%8 = affine.load %0[%arg5] : memref<768xf32>
%9 = affine.load %1[%arg5] : memref<768xf32>
%10 = arith.addf %7, %cst : f32
%11 = math.sqrt %10 : f32
%12 = arith.divf %cst_9, %11 : f32
%13 = arith.mulf %6, %12 : f32
%14 = arith.mulf %13, %8 : f32
%15 = arith.addf %14, %9 : f32
affine.store %15, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
}
}
}
```
The pass erroneously fuses the `divf` instruction into the inner loop. Within each iteration of the 128-trip loop, the `divf` should be executed only once, but because of the incorrect fusion it is now executed 768 times (once per iteration of the innermost 768-trip loop). I also verified this on a real example: the program's output changes after the pass is applied, consistent with what the code above shows.
Could you please provide me with some information on this issue? Is it a bug, or is something wrong with my code or my choice of optimization? Thank you so much for reading!
_______________________________________________
llvm-bugs mailing list
llvm-bugs@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs