https://github.com/ergawy created 
https://github.com/llvm/llvm-project/pull/156589

Extends support for mapping `do concurrent` on the device by adding support for
`local` specifiers. The changes in this PR map the local variable to the
`omp.target` op and use the mapped value as the `private` clause operand in
the nested `omp.parallel` op.

>From 78fc5ed0cbe7211bf89c744b0e8301bfe1722295 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.erg...@amd.com>
Date: Tue, 2 Sep 2025 05:54:00 -0500
Subject: [PATCH] [flang][OpenMP] `do concurrent`: support `local` on device

Extends support for mapping `do concurrent` on the device by adding
support for `local` specifiers. The changes in this PR map the local
variable to the `omp.target` op and use the mapped value as the
`private` clause operand in the nested `omp.parallel` op.
---
 .../include/flang/Optimizer/Dialect/FIROps.td |  12 ++
 .../OpenMP/DoConcurrentConversion.cpp         | 192 +++++++++++-------
 .../Transforms/DoConcurrent/local_device.mlir |  49 +++++
 3 files changed, 175 insertions(+), 78 deletions(-)
 create mode 100644 flang/test/Transforms/DoConcurrent/local_device.mlir

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index bc971e8fd6600..fc6eedc6ed4c6 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -3894,6 +3894,21 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
       return getReduceVars().size();
     }
 
+    /// Index of the first induction variable among the body block arguments.
+    unsigned getInductionVarsStart() {
+      return 0;
+    }
+
+    /// Index of the first `local` argument; locals follow the induction vars.
+    unsigned getLocalOperandsStart() {
+      return getNumInductionVars();
+    }
+
+    /// Index of the first reduction argument; reductions follow the locals.
+    unsigned getReduceOperandsStart() {
+      return getLocalOperandsStart() + getNumLocalOperands();
+    }
+
     mlir::Block::BlockArgListType getInductionVars() {
       return getBody()->getArguments().slice(0, getNumInductionVars());
     }
diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
index a800a20129abe..66b778fecc208 100644
--- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
@@ -137,6 +137,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
 
         liveIns.push_back(operand->get());
       });
+
+  for (mlir::Value local : loop.getLocalVars())
+    liveIns.push_back(local);
 }
 
 /// Collects values that are local to a loop: "loop-local values". A loop-local
@@ -251,8 +254,7 @@ class DoConcurrentConversion
               .getIsTargetDevice();
 
       mlir::omp::TargetOperands targetClauseOps;
-      genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, mapper,
-                           loopNestClauseOps,
+      genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, loopNestClauseOps,
                            isTargetDevice ? nullptr : &targetClauseOps);
 
       LiveInShapeInfoMap liveInShapeInfoMap;
@@ -274,14 +276,13 @@ class DoConcurrentConversion
     }
 
     mlir::omp::ParallelOp parallelOp =
-        genParallelOp(doLoop.getLoc(), rewriter, ivInfos, mapper);
+        genParallelOp(rewriter, loop, ivInfos, mapper);
 
     // Only set as composite when part of `distribute parallel do`.
     parallelOp.setComposite(mapToDevice);
 
     if (!mapToDevice)
-      genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, mapper,
-                           loopNestClauseOps);
+      genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, loopNestClauseOps);
 
     for (mlir::Value local : locals)
       looputils::localizeLoopLocalValue(local, parallelOp.getRegion(),
@@ -290,10 +291,38 @@ class DoConcurrentConversion
     if (mapToDevice)
       genDistributeOp(doLoop.getLoc(), rewriter).setComposite(/*val=*/true);
 
-    mlir::omp::LoopNestOp ompLoopNest =
+    auto [loopNestOp, wsLoopOp] =
         genWsLoopOp(rewriter, loop, mapper, loopNestClauseOps,
                     /*isComposite=*/mapToDevice);
 
+    // `local` and reduction region args were cloned onto the loopnest op
+    // above. Move their uses to the ops that own them: `omp.parallel` for
+    // device `local`s, `omp.wsloop` for host `local`s and all reductions.
+    if (mapToDevice) {
+      for (auto [parallelArg, loopNestArg] : llvm::zip_equal(
+               parallelOp.getRegion().getArguments(),
+               loopNestOp.getRegion().getArguments().slice(
+                   loop.getLocalOperandsStart(), loop.getNumLocalOperands())))
+        rewriter.replaceAllUsesWith(loopNestArg, parallelArg);
+
+      for (auto [wsloopArg, loopNestArg] : llvm::zip_equal(
+               wsLoopOp.getRegion().getArguments(),
+               loopNestOp.getRegion().getArguments().slice(
+                   loop.getReduceOperandsStart(), loop.getNumReduceOperands())))
+        rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
+    } else {
+      for (auto [wsloopArg, loopNestArg] :
+           llvm::zip_equal(wsLoopOp.getRegion().getArguments(),
+                           loopNestOp.getRegion().getArguments().drop_front(
+                               loopNestClauseOps.loopLowerBounds.size())))
+        rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
+    }
+    // Drop the rewired args (each erase shifts the rest down to this index).
+    for (unsigned i = 0;
+         i < loop.getNumLocalOperands() + loop.getNumReduceOperands(); ++i)
+      loopNestOp.getRegion().eraseArgument(
+          loopNestClauseOps.loopLowerBounds.size());
+
     rewriter.setInsertionPoint(doLoop);
     fir::FirOpBuilder builder(
         rewriter,
@@ -314,7 +343,7 @@ class DoConcurrentConversion
     // Mark `unordered` loops that are not perfectly nested to be skipped from
     // the legality check of the `ConversionTarget` since we are not interested
     // in mapping them to OpenMP.
-    ompLoopNest->walk([&](fir::DoConcurrentOp doLoop) {
+    loopNestOp->walk([&](fir::DoConcurrentOp doLoop) {
       concurrentLoopsToSkip.insert(doLoop);
     });
 
@@ -370,11 +399,21 @@ class DoConcurrentConversion
       llvm::DenseMap<mlir::Value, TargetDeclareShapeCreationInfo>;
 
   mlir::omp::ParallelOp
-  genParallelOp(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
+  genParallelOp(mlir::ConversionPatternRewriter &rewriter,
+                fir::DoConcurrentLoopOp loop,
                 looputils::InductionVariableInfos &ivInfos,
                 mlir::IRMapping &mapper) const {
-    auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc);
-    rewriter.createBlock(&parallelOp.getRegion());
+    mlir::omp::ParallelOperands parallelOps;
+    // On the device, `local`s are privatized here rather than on `omp.wsloop`.
+    if (mapToDevice)
+      genPrivatizers(rewriter, mapper, loop, parallelOps);
+
+    mlir::Location loc = loop.getLoc();
+    auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc, parallelOps);
+    Fortran::common::openmp::EntryBlockArgs parallelArgs;
+    parallelArgs.priv.vars = parallelOps.privateVars;
+    Fortran::common::openmp::genEntryBlock(rewriter, parallelArgs,
+                                           parallelOp.getRegion());
     rewriter.setInsertionPoint(mlir::omp::TerminatorOp::create(rewriter, loc));
 
     genLoopNestIndVarAllocs(rewriter, ivInfos, mapper);
@@ -411,7 +450,7 @@ class DoConcurrentConversion
 
   void genLoopNestClauseOps(
       mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
-      fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
+      fir::DoConcurrentLoopOp loop,
       mlir::omp::LoopNestOperands &loopNestClauseOps,
       mlir::omp::TargetOperands *targetClauseOps = nullptr) const {
     assert(loopNestClauseOps.loopLowerBounds.empty() &&
@@ -440,59 +479,14 @@ class DoConcurrentConversion
     loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
   }
 
-  mlir::omp::LoopNestOp
+  std::pair<mlir::omp::LoopNestOp, mlir::omp::WsloopOp>
   genWsLoopOp(mlir::ConversionPatternRewriter &rewriter,
               fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
               const mlir::omp::LoopNestOperands &clauseOps,
               bool isComposite) const {
     mlir::omp::WsloopOperands wsloopClauseOps;
-
-    auto cloneFIRRegionToOMP = [&rewriter](mlir::Region &firRegion,
-                                           mlir::Region &ompRegion) {
-      if (!firRegion.empty()) {
-        rewriter.cloneRegionBefore(firRegion, ompRegion, ompRegion.begin());
-        auto firYield =
-            mlir::cast<fir::YieldOp>(ompRegion.back().getTerminator());
-        rewriter.setInsertionPoint(firYield);
-        mlir::omp::YieldOp::create(rewriter, firYield.getLoc(),
-                                   firYield.getOperands());
-        rewriter.eraseOp(firYield);
-      }
-    };
-
-    // For `local` (and `local_init`) opernads, emit corresponding `private`
-    // clauses and attach these clauses to the workshare loop.
-    if (!loop.getLocalVars().empty())
-      for (auto [op, sym, arg] : llvm::zip_equal(
-               loop.getLocalVars(),
-               loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
-               loop.getRegionLocalArgs())) {
-        auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
-            sym.getLeafReference());
-        if (localizer.getLocalitySpecifierType() ==
-            fir::LocalitySpecifierType::LocalInit)
-          TODO(localizer.getLoc(),
-               "local_init conversion is not supported yet");
-
-        mlir::OpBuilder::InsertionGuard guard(rewriter);
-        rewriter.setInsertionPointAfter(localizer);
-
-        auto privatizer = mlir::omp::PrivateClauseOp::create(
-            rewriter, localizer.getLoc(), sym.getLeafReference().str() + ".omp",
-            localizer.getTypeAttr().getValue(),
-            mlir::omp::DataSharingClauseType::Private);
-
-        cloneFIRRegionToOMP(localizer.getInitRegion(),
-                            privatizer.getInitRegion());
-        cloneFIRRegionToOMP(localizer.getDeallocRegion(),
-                            privatizer.getDeallocRegion());
-
-        moduleSymbolTable.insert(privatizer);
-
-        wsloopClauseOps.privateVars.push_back(op);
-        wsloopClauseOps.privateSyms.push_back(
-            mlir::SymbolRefAttr::get(privatizer));
-      }
+    if (!mapToDevice)
+      genPrivatizers(rewriter, mapper, loop, wsloopClauseOps);
 
     if (!loop.getReduceVars().empty()) {
       for (auto [op, byRef, sym, arg] : llvm::zip_equal(
@@ -515,15 +509,15 @@ class DoConcurrentConversion
               rewriter, firReducer.getLoc(), ompReducerName,
               firReducer.getTypeAttr().getValue());
 
-          cloneFIRRegionToOMP(firReducer.getAllocRegion(),
+          cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
                               ompReducer.getAllocRegion());
-          cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
+          cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
                               ompReducer.getInitializerRegion());
-          cloneFIRRegionToOMP(firReducer.getReductionRegion(),
+          cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
                               ompReducer.getReductionRegion());
-          cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
+          cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
                               ompReducer.getAtomicReductionRegion());
-          cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
+          cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
                               ompReducer.getCleanupRegion());
           moduleSymbolTable.insert(ompReducer);
         }
@@ -555,21 +549,8 @@ class DoConcurrentConversion
 
     rewriter.setInsertionPointToEnd(&loopNestOp.getRegion().back());
     mlir::omp::YieldOp::create(rewriter, loop->getLoc());
 
-    // `local` region arguments are transferred/cloned from the `do concurrent`
-    // loop to the loopnest op when the region is cloned above. Instead, these
-    // region arguments should be on the workshare loop's region.
-    for (auto [wsloopArg, loopNestArg] :
-         llvm::zip_equal(wsloopOp.getRegion().getArguments(),
-                         loopNestOp.getRegion().getArguments().drop_front(
-                             clauseOps.loopLowerBounds.size())))
-      rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
-
-    for (unsigned i = 0;
-         i < loop.getLocalVars().size() + loop.getReduceVars().size(); ++i)
-      loopNestOp.getRegion().eraseArgument(clauseOps.loopLowerBounds.size());
-
-    return loopNestOp;
+    return {loopNestOp, wsloopOp};
   }
 
   void genBoundsOps(fir::FirOpBuilder &builder, mlir::Value liveIn,
@@ -810,6 +791,70 @@ class DoConcurrentConversion
     return distOp;
   }
 
+  /// Clones `firRegion` to the start of `ompRegion` (when `firRegion` is not
+  /// empty) and replaces the `fir.yield` terminating the last cloned block
+  /// with an `omp.yield` that has the same operands.
+  void cloneFIRRegionToOMP(mlir::ConversionPatternRewriter &rewriter,
+                           mlir::Region &firRegion,
+                           mlir::Region &ompRegion) const {
+    if (!firRegion.empty()) {
+      // Keep the caller's insertion point; otherwise it is left inside
+      // `ompRegion`, anchored at the `fir.yield` erased below.
+      mlir::OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.cloneRegionBefore(firRegion, ompRegion, ompRegion.begin());
+      auto firYield =
+          mlir::cast<fir::YieldOp>(ompRegion.back().getTerminator());
+      rewriter.setInsertionPoint(firYield);
+      mlir::omp::YieldOp::create(rewriter, firYield.getLoc(),
+                                 firYield.getOperands());
+      rewriter.eraseOp(firYield);
+    }
+  }
+
+  /// For each `local` operand of `loop`, creates an `omp.private` privatizer
+  /// cloned from the operand's `fir.local` localizer (inserted right after
+  /// it) and appends the operand and privatizer symbol to `privateClauseOps`.
+  /// When mapping to the device, the recorded operand is `mapper`'s mapping
+  /// of the host variable. `local_init` is not supported yet (TODO).
+  void genPrivatizers(mlir::ConversionPatternRewriter &rewriter,
+                      mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
+                      mlir::omp::PrivateClauseOps &privateClauseOps) const {
+    // The caller attaches the resulting `private` clauses: `omp.parallel`
+    // uses them on the device and `omp.wsloop` on the host.
+    if (!loop.getLocalVars().empty())
+      for (auto [var, sym, arg] : llvm::zip_equal(
+               loop.getLocalVars(),
+               loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
+               loop.getRegionLocalArgs())) {
+        auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
+            sym.getLeafReference());
+        if (localizer.getLocalitySpecifierType() ==
+            fir::LocalitySpecifierType::LocalInit)
+          TODO(localizer.getLoc(),
+               "local_init conversion is not supported yet");
+
+        mlir::OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointAfter(localizer);
+
+        auto privatizer = mlir::omp::PrivateClauseOp::create(
+            rewriter, localizer.getLoc(), sym.getLeafReference().str() + ".omp",
+            localizer.getTypeAttr().getValue(),
+            mlir::omp::DataSharingClauseType::Private);
+
+        cloneFIRRegionToOMP(rewriter, localizer.getInitRegion(),
+                            privatizer.getInitRegion());
+        cloneFIRRegionToOMP(rewriter, localizer.getDeallocRegion(),
+                            privatizer.getDeallocRegion());
+
+        moduleSymbolTable.insert(privatizer);
+
+        privateClauseOps.privateVars.push_back(mapToDevice ? mapper.lookup(var)
+                                                           : var);
+        privateClauseOps.privateSyms.push_back(
+            mlir::SymbolRefAttr::get(privatizer));
+      }
+  }
+
   bool mapToDevice;
   llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
   mlir::SymbolTable &moduleSymbolTable;
diff --git a/flang/test/Transforms/DoConcurrent/local_device.mlir b/flang/test/Transforms/DoConcurrent/local_device.mlir
new file mode 100644
index 0000000000000..e54bb1aeb414e
--- /dev/null
+++ b/flang/test/Transforms/DoConcurrent/local_device.mlir
@@ -0,0 +1,49 @@
+// RUN: fir-opt --omp-do-concurrent-conversion="map-to=device" %s -o - | FileCheck %s
+
+fir.local {type = local} @_QFfooEmy_local_private_f32 : f32
+
+func.func @_QPfoo() {
+  %0 = fir.dummy_scope : !fir.dscope
+  %3 = fir.alloca f32 {bindc_name = "my_local", uniq_name = "_QFfooEmy_local"}
+  %4:2 = hlfir.declare %3 {uniq_name = "_QFfooEmy_local"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+
+  fir.do_concurrent {
+    %7 = fir.alloca i32 {bindc_name = "i"}
+    %8:2 = hlfir.declare %7 {uniq_name = "_QFfooEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+    fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) local(@_QFfooEmy_local_private_f32 %4#0 -> %arg1 : !fir.ref<f32>) {
+      %9 = fir.convert %arg0 : (index) -> i32
+      fir.store %9 to %8#0 : !fir.ref<i32>
+      %10:2 = hlfir.declare %arg1 {uniq_name = "_QFfooEmy_local"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+      %cst = arith.constant 4.200000e+01 : f32
+      hlfir.assign %cst to %10#0 : f32, !fir.ref<f32>
+    }
+  }
+  return
+}
+
+// CHECK: omp.private {type = private} @[[OMP_PRIVATIZER:.*.omp]] : f32
+
+// CHECK: %[[LOCAL_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "{{.*}}my_local"}
+// CHECK: %[[LOCAL_MAP:.*]] = omp.map.info var_ptr(%[[LOCAL_DECL]]#1 : {{.*}})
+
+// CHECK: omp.target host_eval({{.*}}) map_entries({{.*}}, %[[LOCAL_MAP]] -> %[[LOCAL_MAP_ARG:.*]] : {{.*}}) {
+// CHECK:   %[[LOCAL_DEV_DECL:.*]]:2 = hlfir.declare %[[LOCAL_MAP_ARG]] {uniq_name = "_QFfooEmy_local"}
+
+// CHECK:   omp.teams {
+// CHECK:     omp.parallel private(@[[OMP_PRIVATIZER]] %[[LOCAL_DEV_DECL]]#0 -> %[[LOCAL_PRIV_ARG:.*]] : {{.*}}) {
+// CHECK:       omp.distribute {
+// CHECK:         omp.wsloop {
+// CHECK:           omp.loop_nest {{.*}} {
+// CHECK:             %[[LOCAL_LOOP_DECL:.*]]:2 = hlfir.declare %[[LOCAL_PRIV_ARG]] {uniq_name = "_QFfooEmy_local"}
+// CHECK:             hlfir.assign %{{.*}} to %[[LOCAL_LOOP_DECL]]#0
+// CHECK:             omp.yield
+// CHECK:           }
+// CHECK:         }
+// CHECK:       }
+// CHECK:     }
+// CHECK:   }
+// CHECK: }

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to