Issue 95230
Summary Unexpected Behavior on Affine-Loop-Fusion
Labels new issue
Assignees
Reporter sgjzfzzf
    Hi, I'm developing a compiler for ONNX based on MLIR. I'm trying to optimize the code generation with the Affine passes, but I've run into a problem: incorrect code is produced for the LayerNormalization operator. Here is the code generated automatically by my compiler.

```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
#map2 = affine_map<(d0, d1, d2) -> (d2)>
module {
 func.func @layer_normalization(%arg0: memref<1x128x768xf32>, %arg1: memref<1x128x768xf32>, %arg2: memref<512xi8>) attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %view = memref.view %arg2[%c0][] : memref<512xi8> to memref<512xi8>
    %cst = arith.constant dense<1.000000e+00> : tensor<768xf32>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<768xf32>
    %c0_1 = arith.constant 0 : index
    %cst_2 = arith.constant 0.000000e+00 : f32
    %cst_3 = arith.constant 1.000000e+00 : f32
    %view_4 = memref.view %view[%c0_1][] : memref<512xi8> to memref<1x128x1xf32>
    %0 = bufferization.to_memref %cst : memref<768xf32>
    %1 = bufferization.to_memref %cst_0 : memref<768xf32>
    linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg1 : memref<1x128x768xf32>) outs(%view_4 : memref<1x128x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %2 = arith.addf %in, %out : f32
      linalg.yield %2 : f32
    }
 linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%view_4 : memref<1x128x1xf32>) outs(%view_4 : memref<1x128x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %cst_5 = arith.constant 7.680000e+02 : f32
      %2 = arith.divf %in, %cst_5 : f32
      linalg.yield %2 : f32
    }
    linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1, %view_4 : memref<1x128x768xf32>, memref<1x128x1xf32>) outs(%arg0 : memref<1x128x768xf32>) {
    ^bb0(%in: f32, %in_5: f32, %out: f32):
      %2 = arith.subf %in, %in_5 : f32
 linalg.yield %2 : f32
    }
    linalg.fill ins(%cst_2 : f32) outs(%view_4 : memref<1x128x1xf32>)
    linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : memref<1x128x768xf32>) outs(%view_4 : memref<1x128x1xf32>) {
 ^bb0(%in: f32, %out: f32):
      %2 = arith.mulf %in, %in : f32
 %3 = arith.addf %out, %2 : f32
      linalg.yield %3 : f32
    }
 linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%view_4 : memref<1x128x1xf32>) outs(%view_4 : memref<1x128x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %cst_5 = arith.constant 7.680000e+02 : f32
      %2 = arith.divf %in, %cst_5 : f32
      linalg.yield %2 : f32
    }
    linalg.generic {indexing_maps = [#map, #map1, #map2, #map2, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %view_4, %0, %1 : memref<1x128x768xf32>, memref<1x128x1xf32>, memref<768xf32>, memref<768xf32>) outs(%arg0 : memref<1x128x768xf32>) {
    ^bb0(%in: f32, %in_5: f32, %in_6: f32, %in_7: f32, %out: f32):
      %cst_8 = arith.constant 9.99999996E-13 : f32
      %2 = arith.addf %in_5, %cst_8 : f32
      %3 = math.sqrt %2 : f32
      %4 = arith.divf %cst_3, %3 : f32
      %5 = arith.mulf %in, %4 : f32
      %6 = arith.mulf %5, %in_6 : f32
      %7 = arith.addf %6, %in_7 : f32
      linalg.yield %7 : f32
    }
 return
  }
}
```

Then, I use `mlir-opt-18 -convert-linalg-to-affine-loops <filename>` to lower it to the Affine dialect.

```mlir
module {
  func.func @layer_normalization(%arg0: memref<1x128x768xf32>, %arg1: memref<1x128x768xf32>, %arg2: memref<512xi8>) attributes {llvm.emit_c_interface} {
    %cst = arith.constant 9.99999996E-13 : f32
    %cst_0 = arith.constant 7.680000e+02 : f32
    %cst_1 = arith.constant 1.000000e+00 : f32
    %cst_2 = arith.constant 0.000000e+00 : f32
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<768xf32>
    %cst_4 = arith.constant dense<1.000000e+00> : tensor<768xf32>
    %c0 = arith.constant 0 : index
    %view = memref.view %arg2[%c0][] : memref<512xi8> to memref<512xi8>
    %view_5 = memref.view %view[%c0][] : memref<512xi8> to memref<1x128x1xf32>
    %0 = bufferization.to_memref %cst_4 : memref<768xf32>
    %1 = bufferization.to_memref %cst_3 : memref<768xf32>
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg1[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = arith.addf %2, %3 : f32
 affine.store %4, %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
 }
      }
    }
    affine.for %arg3 = 0 to 1 {
 affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 1 {
 %2 = affine.load %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
 %3 = arith.divf %2, %cst_0 : f32
          affine.store %3, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
        }
 }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg1[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
 %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
 %4 = arith.subf %2, %3 : f32
          affine.store %4, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
        }
      }
    }
 affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
 affine.for %arg5 = 0 to 1 {
          affine.store %cst_2, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
        }
      }
    }
 affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
 affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = arith.mulf %2, %2 : f32
          %5 = arith.addf %3, %4 : f32
 affine.store %5, %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
 }
      }
    }
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 1 {
          %2 = affine.load %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
 %3 = arith.divf %2, %cst_0 : f32
          affine.store %3, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
        }
      }
    }
 affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
 affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = affine.load %0[%arg5] : memref<768xf32>
          %5 = affine.load %1[%arg5] : memref<768xf32>
          %6 = arith.addf %3, %cst : f32
 %7 = math.sqrt %6 : f32
          %8 = arith.divf %cst_1, %7 : f32
          %9 = arith.mulf %2, %8 : f32
          %10 = arith.mulf %9, %4 : f32
          %11 = arith.addf %10, %5 : f32
 affine.store %11, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
 }
      }
    }
    return
  }
}
```

After that, I decided to use the Affine-Loop-Fusion pass to optimize it with `mlir-opt-18 -convert-linalg-to-affine-loops -affine-loop-fusion=mode=greedy <filename>`.

```mlir
module {
  func.func @layer_normalization(%arg0: memref<1x128x768xf32>, %arg1: memref<1x128x768xf32>, %arg2: memref<512xi8>) attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %c0_0 = arith.constant 0 : index
    %c0_1 = arith.constant 0 : index
 %c0_2 = arith.constant 0 : index
    %c0_3 = arith.constant 0 : index
 %c0_4 = arith.constant 0 : index
    %c0_5 = arith.constant 0 : index
    %c0_6 = arith.constant 0 : index
    %c0_7 = arith.constant 0 : index
    %cst = arith.constant 9.99999996E-13 : f32
    %cst_8 = arith.constant 7.680000e+02 : f32
    %cst_9 = arith.constant 1.000000e+00 : f32
    %cst_10 = arith.constant 0.000000e+00 : f32
    %cst_11 = arith.constant dense<0.000000e+00> : tensor<768xf32>
    %cst_12 = arith.constant dense<1.000000e+00> : tensor<768xf32>
    %c0_13 = arith.constant 0 : index
    %view = memref.view %arg2[%c0_13][] : memref<512xi8> to memref<512xi8>
    %view_14 = memref.view %view[%c0_13][] : memref<512xi8> to memref<1x128x1xf32>
    %0 = bufferization.to_memref %cst_12 : memref<768xf32>
    %1 = bufferization.to_memref %cst_11 : memref<768xf32>
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 768 {
          %4 = affine.load %arg1[%c0, %arg4, %arg5] : memref<1x128x768xf32>
          %5 = affine.load %view_14[%c0, %arg4, %c0_13] : memref<1x128x1xf32>
          %6 = arith.addf %4, %5 : f32
 affine.store %6, %view_14[%c0, %arg4, %c0_13] : memref<1x128x1xf32>
        }
        %2 = affine.load %view_14[%c0_1, %arg4, %c0_0] : memref<1x128x1xf32>
        %3 = arith.divf %2, %cst_8 : f32
        affine.store %3, %view_14[%c0_1, %arg4, %c0_0] : memref<1x128x1xf32>
        affine.for %arg5 = 0 to 768 {
          %4 = affine.load %arg1[%c0_2, %arg4, %arg5] : memref<1x128x768xf32>
 %5 = affine.load %view_14[%c0_2, %arg4, %c0_13] : memref<1x128x1xf32>
 %6 = arith.subf %4, %5 : f32
          affine.store %6, %arg0[%c0_2, %arg4, %arg5] : memref<1x128x768xf32>
        }
        affine.store %cst_10, %view_14[%c0_4, %arg4, %c0_3] : memref<1x128x1xf32>
 affine.for %arg5 = 0 to 768 {
          %4 = affine.load %arg0[%c0_5, %arg4, %arg5] : memref<1x128x768xf32>
          %5 = affine.load %view_14[%c0_5, %arg4, %c0_13] : memref<1x128x1xf32>
          %6 = arith.mulf %4, %4 : f32
          %7 = arith.addf %5, %6 : f32
 affine.store %7, %view_14[%c0_5, %arg4, %c0_13] : memref<1x128x1xf32>
 }
        affine.for %arg5 = 0 to 768 {
          %4 = affine.load %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
          %5 = arith.divf %4, %cst_8 : f32
          affine.store %5, %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
          %6 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %7 = affine.load %view_14[%arg3, %arg4, %c0_13] : memref<1x128x1xf32>
 %8 = affine.load %0[%arg5] : memref<768xf32>
          %9 = affine.load %1[%arg5] : memref<768xf32>
          %10 = arith.addf %7, %cst : f32
 %11 = math.sqrt %10 : f32
          %12 = arith.divf %cst_9, %11 : f32
          %13 = arith.mulf %6, %12 : f32
          %14 = arith.mulf %13, %8 : f32
          %15 = arith.addf %14, %9 : f32
 affine.store %15, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
 }
      }
    }
    return
  }
}
```

Please take a look at the last `affine.for` loop nest. A before/after comparison follows:
```mlir
 // ...
    affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
        affine.for %arg5 = 0 to 1 {
          %2 = affine.load %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
          %3 = arith.divf %2, %cst_0 : f32
          affine.store %3, %view_5[%arg3, %arg4, %arg5] : memref<1x128x1xf32>
        }
      }
    }
 affine.for %arg3 = 0 to 1 {
      affine.for %arg4 = 0 to 128 {
 affine.for %arg5 = 0 to 768 {
          %2 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
          %3 = affine.load %view_5[%arg3, %arg4, %c0] : memref<1x128x1xf32>
          %4 = affine.load %0[%arg5] : memref<768xf32>
          %5 = affine.load %1[%arg5] : memref<768xf32>
          %6 = arith.addf %3, %cst : f32
 %7 = math.sqrt %6 : f32
          %8 = arith.divf %cst_1, %7 : f32
          %9 = arith.mulf %2, %8 : f32
          %10 = arith.mulf %9, %4 : f32
          %11 = arith.addf %10, %5 : f32
 affine.store %11, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
 }
      }
    }


    affine.for %arg3 = 0 to 1 {
 affine.for %arg4 = 0 to 128 {
        // ...
        affine.for %arg5 = 0 to 768 {
          %4 = affine.load %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
          %5 = arith.divf %4, %cst_8 : f32
 affine.store %5, %view_14[%c0_7, %arg4, %c0_6] : memref<1x128x1xf32>
 %6 = affine.load %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
 %7 = affine.load %view_14[%arg3, %arg4, %c0_13] : memref<1x128x1xf32>
          %8 = affine.load %0[%arg5] : memref<768xf32>
          %9 = affine.load %1[%arg5] : memref<768xf32>
 %10 = arith.addf %7, %cst : f32
          %11 = math.sqrt %10 : f32
          %12 = arith.divf %cst_9, %11 : f32
          %13 = arith.mulf %6, %12 : f32
          %14 = arith.mulf %13, %8 : f32
 %15 = arith.addf %14, %9 : f32
          affine.store %15, %arg0[%arg3, %arg4, %arg5] : memref<1x128x768xf32>
        }
 }
    }
```

The pass fuses the `divf` instruction into the wrong loop. For each iteration of the 128-trip loop, the `divf` should be executed only once, but because of the incorrect fusion it is placed inside the 768-trip inner loop and executed 768 times instead. I also verified this on a real example: the program's output changes after running the pass, as can be seen in the code above.

Could you please provide me with some information on this issue? Is it a bug in the pass, or is something wrong with my code or my optimization pipeline? Thank you so much for reading!
_______________________________________________
llvm-bugs mailing list
llvm-bugs@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs

Reply via email to