Issue 138595
Summary [MLIR] convert-linalg-to-parallel-loops doesn't seem to parallelize reductions
Labels
Assignees
Reporter brian-kelley
    I have a project that uses MLIR to lower linalg/tensor code to scf/memref. Ideally the resulting code exposes all the parallelism available in the input operations. One of the passes I'm using is ``--convert-linalg-to-parallel-loops``, but for reduction dimensions it falls back to a sequential ``scf.for`` loop instead of ``scf.parallel``/``scf.reduce``.

Below are two simple examples (dot and matmul) where the pass lowers the reduction dimension to an ``scf.for`` loop. To replicate, run them through ``mlir-opt --convert-linalg-to-parallel-loops``.

Is there another way to generate parallel loops from linalg ops that keeps reductions parallel? One alternative I tried is ``--convert-linalg-to-affine-loops --affine-parallelize``, but it likewise parallelizes only the non-reduction dimensions (and the affine dialect does not seem to be able to express a parallel reduction anyway).

dot.mlir:
```mlir
module {
  func.func @dot(%arg0: memref<?xf64>, %arg1: memref<?xf64>) -> f64 {
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<f64>
    linalg.dot ins(%arg0, %arg1 : memref<?xf64>, memref<?xf64>) outs(%alloc : memref<f64>)
    %0 = memref.load %alloc[] : memref<f64>
    return %0 : f64
  }
}
```
result:
```mlir
module {
  func.func @dot(%arg0: memref<?xf64>, %arg1: memref<?xf64>) -> f64 {
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<f64>
    %dim = memref.dim %arg0, %c0 : memref<?xf64>
    scf.for %arg2 = %c0 to %dim step %c1 {
      %1 = memref.load %arg0[%arg2] : memref<?xf64>
      %2 = memref.load %arg1[%arg2] : memref<?xf64>
      %3 = memref.load %alloc[] : memref<f64>
      %4 = arith.mulf %1, %2 : f64
      %5 = arith.addf %3, %4 : f64
      memref.store %5, %alloc[] : memref<f64>
    }
    %0 = memref.load %alloc[] : memref<f64>
    return %0 : f64
  }
}
```
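
For comparison, a hand-written sketch of the parallel-reduction form I would have expected for the dot example (using the ``scf.reduce`` terminator syntax from recent MLIR; ``%zero`` is an f64 zero constant added here for the init value, and the syntax may differ slightly across MLIR versions):

```mlir
// Sketch only: the reduction loop body feeds scf.reduce, which combines
// partial products with a floating-point add.
%zero = arith.constant 0.0 : f64
%sum = scf.parallel (%i) = (%c0) to (%dim) step (%c1) init (%zero) -> f64 {
  %a = memref.load %arg0[%i] : memref<?xf64>
  %b = memref.load %arg1[%i] : memref<?xf64>
  %prod = arith.mulf %a, %b : f64
  scf.reduce(%prod : f64) {
  ^bb0(%lhs: f64, %rhs: f64):
    %acc = arith.addf %lhs, %rhs : f64
    scf.reduce.return %acc : f64
  }
}
```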

matmul.mlir:
```mlir
module {
  func.func @dot(%arg0: memref<64x64xf64>, %arg1: memref<64x64xf64>) -> memref<64x64xf64> {
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<64x64xf64>
    linalg.matmul ins(%arg0, %arg1 : memref<64x64xf64>, memref<64x64xf64>) outs(%alloc : memref<64x64xf64>)
    return %alloc : memref<64x64xf64>
  }
}
```
result:
```mlir
module {
  func.func @dot(%arg0: memref<64x64xf64>, %arg1: memref<64x64xf64>) -> memref<64x64xf64> {
    %c1 = arith.constant 1 : index
    %c64 = arith.constant 64 : index
    %c0 = arith.constant 0 : index
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<64x64xf64>
    scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c64, %c64) step (%c1, %c1) {
      scf.for %arg4 = %c0 to %c64 step %c1 {
        %0 = memref.load %arg0[%arg2, %arg4] : memref<64x64xf64>
        %1 = memref.load %arg1[%arg4, %arg3] : memref<64x64xf64>
        %2 = memref.load %alloc[%arg2, %arg3] : memref<64x64xf64>
        %3 = arith.mulf %0, %1 : f64
        %4 = arith.addf %2, %3 : f64
        memref.store %4, %alloc[%arg2, %arg3] : memref<64x64xf64>
      }
      scf.reduce 
    }
    return %alloc : memref<64x64xf64>
  }
}
```
_______________________________________________
llvm-bugs mailing list
llvm-bugs@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs