================
@@ -0,0 +1,259 @@
+//===- LowerWorkshare.cpp - Lower omp.workshare construct -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Lower omp workshare construct.
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/iterator_range.h"
+
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKSHARE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workshare"
+
+using namespace mlir;
+
+namespace flangomp {
+/// Returns true when `op` is an immediate child of an omp.workshare operation
+/// that is itself nested (at any depth) inside an omp.parallel operation.
+///
+/// Returns false when `op` has no parent operation, when its parent is not an
+/// omp.workshare, or when no enclosing omp.parallel exists.
+bool shouldUseWorkshareLowering(Operation *op) {
+  // getParentOp() yields null for an operation without a parent; plain
+  // dyn_cast asserts on a null input, so use the null-tolerant variant.
+  auto workshare = dyn_cast_or_null<omp::WorkshareOp>(op->getParentOp());
+  if (!workshare)
+    return false;
+  return workshare->getParentOfType<omp::ParallelOp>();
+}
+} // namespace flangomp
+
+namespace {
+
+/// A contiguous run of operations inside a block, described by the half-open
+/// iterator range [begin, end): `begin` is the first operation of the run and
+/// `end` is the position just past its last operation.
+struct SingleRegion {
+  Block::iterator begin;
+  Block::iterator end;
+};
+
+/// Returns true for every type except fir::ReferenceType.
+/// NOTE(review): judging by the name, reference types presumably cannot be
+/// used as the allocated type of a fir.alloca -- confirm against the FIR
+/// dialect definition.
+static bool isSupportedByFirAlloca(Type ty) {
+  const bool isReference = isa<fir::ReferenceType>(ty);
+  return !isReference;
+}
+
+/// Decides whether `op` may be treated as safe to parallelize.
+///
+/// fir.declare is always accepted. Any other operation is accepted only if it
+/// implements MemoryEffectOpInterface and reports no memory effects at all;
+/// operations that do not implement the interface are rejected.
+static bool isSafeToParallelize(Operation *op) {
+  if (isa<fir::DeclareOp>(op))
+    return true;
+
+  // Without the interface nothing is known about the op's effects, so it is
+  // conservatively rejected.
+  auto memEffects = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!memEffects)
+    return false;
+
+  llvm::SmallVector<MemoryEffects::EffectInstance> reportedEffects;
+  memEffects.getEffects(reportedEffects);
+  return reportedEffects.empty();
+}
+
+/// Lowers workshare to a sequence of single-thread regions and parallel loops
+///
+/// For example:
+///
+/// omp.workshare {
+///   %a = fir.allocmem
+///   omp.wsloop {}
+///   fir.call Assign %b %a
+///   fir.freemem %a
+/// }
+///
+/// becomes
+///
+/// omp.single {
+///   %a = fir.allocmem
+///   fir.store %a %tmp
+/// }
+/// %a_reloaded = fir.load %tmp
+/// omp.wsloop {}
+/// omp.single {
+///   fir.call Assign %b %a_reloaded
+///   fir.freemem %a_reloaded
+/// }
+///
+/// Note that we allocate temporary memory for values in omp.single's which
+/// need to be accessed in all threads in the closest omp.parallel
+///
+/// TODO currently we need to be able to access the encompassing omp.parallel
+/// so that we can allocate temporaries accessible by all threads outside of
+/// it.
+/// In case we do not find it, we fall back to converting the omp.workshare to
+/// omp.single.
+/// To better handle this we should probably enable yielding values out of an
+/// omp.single which will be supported by the omp runtime.
+void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
+  assert(wsOp.getRegion().getBlocks().size() == 1);
+
+  Location loc = wsOp->getLoc();
+
+  omp::ParallelOp parallelOp = wsOp->getParentOfType<omp::ParallelOp>();
+  if (!parallelOp) {
+    wsOp.emitWarning("cannot handle workshare, converting to single");
+    Operation *terminator = wsOp.getRegion().front().getTerminator();
+    wsOp->getBlock()->getOperations().splice(
+        wsOp->getIterator(), wsOp.getRegion().front().getOperations());
+    terminator->erase();
+    return;
+  }
+
+  OpBuilder allocBuilder(parallelOp);
+  OpBuilder rootBuilder(wsOp);
+  IRMapping rootMapping;
+
+  omp::SingleOp singleOp = nullptr;
+
+  auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
+                              IRMapping singleMapping) {
+    if (auto reloaded = rootMapping.lookupOrNull(v))
+      return;
+    Type llvmPtrTy = LLVM::LLVMPointerType::get(allocBuilder.getContext());
+    Type ty = v.getType();
+    Value alloc, reloaded;
+    if (isSupportedByFirAlloca(ty)) {
+      alloc = allocBuilder.create<fir::AllocaOp>(loc, ty);
+      singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
+      reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
+    } else {
+      auto one = allocBuilder.create<LLVM::ConstantOp>(
+          loc, allocBuilder.getI32Type(), 1);
+      alloc =
+          allocBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
+      Value toStore = singleBuilder
+                          .create<UnrealizedConversionCastOp>(
+                              loc, llvmPtrTy, singleMapping.lookup(v))
+                          .getResult(0);
+      singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
+      reloaded = rootBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
+      reloaded =
+          rootBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
+              .getResult(0);
+    }
+    rootMapping.map(v, reloaded);
+  };
+
+  auto moveToSingle = [&](SingleRegion sr, OpBuilder singleBuilder) {
+    IRMapping singleMapping = rootMapping;
+
+    for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
+      singleBuilder.clone(op, singleMapping);
+      if (isSafeToParallelize(&op)) {
+        rootBuilder.clone(op, rootMapping);
+      } else {
+        // Prepare reloaded values for results of operations that cannot be
+        // safely parallelized and which are used after the region `sr`
+        for (auto res : op.getResults()) {
+          for (auto &use : res.getUses()) {
+            Operation *user = use.getOwner();
+            while (user->getParentOp() != wsOp)
+              user = user->getParentOp();
+            if (!user->isBeforeInBlock(&*sr.end)) {
+              // We need to reload
+              mapReloadedValue(use.get(), singleBuilder, singleMapping);
+            }
+          }
+        }
+      }
+    }
+    singleBuilder.create<omp::TerminatorOp>(loc);
+  };
+
+  Block *wsBlock = &wsOp.getRegion().front();
+  assert(wsBlock->getTerminator()->getNumOperands() == 0);
+  Operation *terminator = wsBlock->getTerminator();
+
+  SmallVector<std::variant<SingleRegion, omp::WsloopOp>> regions;
+
+  auto it = wsBlock->begin();
+  auto getSingleRegion = [&]() {
+    if (&*it == terminator)
+      return false;
+    if (auto pop = dyn_cast<omp::WsloopOp>(&*it)) {
+      regions.push_back(pop);
+      it++;
+      return true;
+    }
+    SingleRegion sr;
+    sr.begin = it;
+    while (&*it != terminator && !isa<omp::WsloopOp>(&*it))
+      it++;
+    sr.end = it;
+    assert(sr.begin != sr.end);
+    regions.push_back(sr);
+    return true;
+  };
+  while (getSingleRegion())
+    ;
+
+  for (auto [i, loopOrSingle] : llvm::enumerate(regions)) {
+    bool isLast = i + 1 == regions.size();
+    if (std::holds_alternative<SingleRegion>(loopOrSingle)) {
+      omp::SingleOperands singleOperands;
+      if (isLast)
+        singleOperands.nowait = rootBuilder.getUnitAttr();
+      singleOp = rootBuilder.create<omp::SingleOp>(loc, singleOperands);
+      OpBuilder singleBuilder(singleOp);
+      singleBuilder.createBlock(&singleOp.getRegion());
+      moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
+    } else {
+      rootBuilder.clone(*std::get<omp::WsloopOp>(loopOrSingle), rootMapping);
+      if (!isLast)
+        rootBuilder.create<omp::BarrierOp>(loc);
+    }
+  }
+
+  if (!wsOp.getNowait())
+    rootBuilder.create<omp::BarrierOp>(loc);
----------------
ivanradanov wrote:

This specific barrier is for the synchronisation at the end of omp.workshare, 
which is mandated by the standard unless there is a nowait.

https://github.com/llvm/llvm-project/pull/101446
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to