This revision was automatically updated to reflect the committed changes.
Closed by commit rG614958912784: [OMPIRBuilder] Support depend clause for task
(authored by psoni2628, committed by Prabhdeep Singh Soni (A)).
Herald added a subscriber: cfe-commits.

  rG LLVM Github Monorepo



Index: llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
--- llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5092,6 +5092,81 @@
   EXPECT_FALSE(verifyModule(*M, &errs()));
+TEST_F(OpenMPIRBuilderTest, CreateTaskDepend) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
+  BasicBlock *AllocaBB = Builder.GetInsertBlock();
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  OpenMPIRBuilder::LocationDescription Loc(
+      InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
+  AllocaInst *InDep = Builder.CreateAlloca(Type::getInt32Ty(M->getContext()));
+  OpenMPIRBuilder::DependData DDIn(RTLDependenceKindTy::DepIn,
+                                   Type::getInt32Ty(M->getContext()), InDep);
+  SmallVector<OpenMPIRBuilder::DependData *, 4> DDS;
+  DDS.push_back(&DDIn);
+  Builder.restoreIP(OMPBuilder.createTask(
+      Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), BodyGenCB,
+      /*Tied=*/false, /*Final*/ nullptr, /*IfCondition*/ nullptr, DDS));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+  // Check for the `NumDeps` argument
+  CallInst *TaskAllocCall = dyn_cast<CallInst>(
+      OMPBuilder
+          .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps)
+          ->user_back());
+  ASSERT_NE(TaskAllocCall, nullptr);
+  ConstantInt *NumDeps = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(3));
+  ASSERT_NE(NumDeps, nullptr);
+  EXPECT_EQ(NumDeps->getZExtValue(), 1U);
+  // Check for the `DepInfo` array argument
+  BitCastInst *DepArrayPtr =
+      dyn_cast<BitCastInst>(TaskAllocCall->getOperand(4));
+  ASSERT_NE(DepArrayPtr, nullptr);
+  AllocaInst *DepArray = dyn_cast<AllocaInst>(DepArrayPtr->getOperand(0));
+  ASSERT_NE(DepArray, nullptr);
+  Value::user_iterator DepArrayI = DepArray->user_begin();
+  EXPECT_EQ(*DepArrayI, DepArrayPtr);
+  ++DepArrayI;
+  Value::user_iterator DepInfoI = DepArrayI->user_begin();
+  // Check for the `DependKind` flag in the `DepInfo` array
+  Value *Flag = findStoredValue<GetElementPtrInst>(*DepInfoI);
+  ASSERT_NE(Flag, nullptr);
+  ConstantInt *FlagInt = dyn_cast<ConstantInt>(Flag);
+  ASSERT_NE(FlagInt, nullptr);
+  EXPECT_EQ(FlagInt->getZExtValue(),
+            static_cast<unsigned int>(RTLDependenceKindTy::DepIn));
+  ++DepInfoI;
+  // Check for the size in the `DepInfo` array
+  Value *Size = findStoredValue<GetElementPtrInst>(*DepInfoI);
+  ASSERT_NE(Size, nullptr);
+  ConstantInt *SizeInt = dyn_cast<ConstantInt>(Size);
+  ASSERT_NE(SizeInt, nullptr);
+  EXPECT_EQ(SizeInt->getZExtValue(), 4U);
+  ++DepInfoI;
+  // Check for the variable address in the `DepInfo` array
+  Value *AddrStored = findStoredValue<GetElementPtrInst>(*DepInfoI);
+  ASSERT_NE(AddrStored, nullptr);
+  PtrToIntInst *AddrInt = dyn_cast<PtrToIntInst>(AddrStored);
+  ASSERT_NE(AddrInt, nullptr);
+  Value *Addr = AddrInt->getPointerOperand();
+  EXPECT_EQ(Addr, InDep);
+  ConstantInt *NumDepsNoAlias =
+      dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(5));
+  ASSERT_NE(NumDepsNoAlias, nullptr);
+  EXPECT_EQ(NumDepsNoAlias->getZExtValue(), 0U);
+  EXPECT_EQ(TaskAllocCall->getOperand(6),
+            ConstantPointerNull::get(Type::getInt8PtrTy(M->getContext())));
+  EXPECT_FALSE(verifyModule(*M, &errs()));
 TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);
Index: llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1290,7 +1290,8 @@
 OpenMPIRBuilder::createTask(const LocationDescription &Loc,
                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
-                            bool Tied, Value *Final, Value *IfCondition) {
+                            bool Tied, Value *Final, Value *IfCondition,
+                            ArrayRef<DependData *> Dependencies) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
@@ -1322,8 +1323,8 @@
   OI.EntryBB = TaskAllocaBB;
   OI.OuterAllocaBB = AllocaIP.getBlock();
   OI.ExitBB = TaskExitBB;
-  OI.PostOutlineCB = [this, Ident, Tied, Final,
-                      IfCondition](Function &OutlinedFn) {
+  OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition,
+                      Dependencies](Function &OutlinedFn) {
     // The input IR here looks like the following-
     // ```
     // func @current_fn() {
@@ -1433,6 +1434,49 @@
+    Value *DepArrayPtr = nullptr;
+    if (Dependencies.size()) {
+      InsertPointTy OldIP = Builder.saveIP();
+      Builder.SetInsertPoint(
+          &OldIP.getBlock()->getParent()->getEntryBlock().back());
+      Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
+      Value *DepArray =
+          Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
+      unsigned P = 0;
+      for (DependData *Dep : Dependencies) {
+        Value *Base =
+            Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, P);
+        // Store the pointer to the variable
+        Value *Addr = Builder.CreateStructGEP(
+            DependInfo, Base,
+            static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
+        Value *DepValPtr =
+            Builder.CreatePtrToInt(Dep->DepVal, Builder.getInt64Ty());
+        Builder.CreateStore(DepValPtr, Addr);
+        // Store the size of the variable
+        Value *Size = Builder.CreateStructGEP(
+            DependInfo, Base,
+            static_cast<unsigned int>(RTLDependInfoFields::Len));
+        Builder.CreateStore(Builder.getInt64(M.getDataLayout().getTypeStoreSize(
+                                Dep->DepValueType)),
+                            Size);
+        // Store the dependency kind
+        Value *Flags = Builder.CreateStructGEP(
+            DependInfo, Base,
+            static_cast<unsigned int>(RTLDependInfoFields::Flags));
+        Builder.CreateStore(
+            ConstantInt::get(Builder.getInt8Ty(),
+                             static_cast<unsigned int>(Dep->DepKind)),
+            Flags);
+        ++P;
+      }
+      DepArrayPtr = Builder.CreateBitCast(DepArray, Builder.getInt8PtrTy());
+      Builder.restoreIP(OldIP);
+    }
     // In the presence of the `if` clause, the following IR is generated:
     //    ...
     //    %data = call @__kmpc_omp_task_alloc(...)
@@ -1471,9 +1515,21 @@
       Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, NewTaskData});
-    // Emit the @__kmpc_omp_task runtime call to spawn the task
-    Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
-    Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
+    if (Dependencies.size()) {
+      Function *TaskFn =
+          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
+      Builder.CreateCall(
+          TaskFn,
+          {Ident, ThreadID, NewTaskData, Builder.getInt32(Dependencies.size()),
+           DepArrayPtr, ConstantInt::get(Builder.getInt32Ty(), 0),
+           ConstantPointerNull::get(Type::getInt8PtrTy(M.getContext()))});
+    } else {
+      // Emit the @__kmpc_omp_task runtime call to spawn the task
+      Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
+      Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
+    }
Index: llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
--- llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -92,6 +92,7 @@
 __OMP_STRUCT_TYPE(KernelArgs, __tgt_kernel_arguments, Int32, Int32, VoidPtrPtr,
                   VoidPtrPtr, Int64Ptr, Int64Ptr, VoidPtrPtr, VoidPtrPtr, Int64)
 __OMP_STRUCT_TYPE(AsyncInfo, __tgt_async_info, Int8Ptr)
+__OMP_STRUCT_TYPE(DependInfo, kmp_dep_info, SizeTy, SizeTy, Int8)
Index: llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
--- llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -645,6 +645,17 @@
   /// \param Loc The location where the taskyield directive was encountered.
   void createTaskyield(const LocationDescription &Loc);
+  /// A struct to pack the relevant information for an OpenMP depend clause.
+  ///
+  /// A default-constructed DependData has kind DepUnknown and null type and
+  /// value pointers, so no member is ever left indeterminate.
+  struct DependData {
+    /// The runtime dependence kind written into the dependence record's flags.
+    omp::RTLDependenceKindTy DepKind = omp::RTLDependenceKindTy::DepUnknown;
+    /// Type of the dependence variable; createTask records its store size as
+    /// the length of the dependence.
+    Type *DepValueType = nullptr;
+    /// Pointer to the dependence variable; createTask converts it to an
+    /// integer and records it as the dependence base address.
+    Value *DepVal = nullptr;
+    explicit DependData() = default;
+    DependData(omp::RTLDependenceKindTy DepKind, Type *DepValueType,
+               Value *DepVal)
+        : DepKind(DepKind), DepValueType(DepValueType), DepVal(DepVal) {}
+  };
   /// Generator for `#omp task`
   /// \param Loc The location where the task construct was encountered.
@@ -662,7 +673,8 @@
   InsertPointTy createTask(const LocationDescription &Loc,
                            InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
                            bool Tied = true, Value *Final = nullptr,
-                           Value *IfCondition = nullptr);
+                           Value *IfCondition = nullptr,
+                           ArrayRef<DependData *> Dependencies = {});
   /// Generator for the taskgroup construct
Index: llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
--- llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
@@ -207,6 +207,19 @@
 /// Atomic compare operations. Currently OpenMP only supports ==, >, and <.
 enum class OMPAtomicCompareOp : unsigned { EQ, MIN, MAX };
+/// Field indices into a kmp_depend_info record: base address, length, flags.
+/// The enumerator order must match the record's field order, because the
+/// enumerator values are used directly as struct field indices.
+enum class RTLDependInfoFields { BaseAddr, Len, Flags };
+/// Dependence kind for RTL.
+enum class RTLDependenceKindTy {
+  DepUnknown = 0x0,
+  DepIn = 0x01,
+  DepInOut = 0x3,
+  DepMutexInOutSet = 0x4,
+  DepInOutSet = 0x8,
+  DepOmpAllMem = 0x80,
 } // end namespace omp
 } // end namespace llvm
Index: clang/lib/CodeGen/CGOpenMPRuntime.cpp
--- clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -4377,39 +4377,26 @@
   return Result;
-namespace {
-/// Dependence kind for RTL.
-enum RTLDependenceKindTy {
-  DepIn = 0x01,
-  DepInOut = 0x3,
-  DepMutexInOutSet = 0x4,
-  DepInOutSet = 0x8,
-  DepOmpAllMem = 0x80,
-/// Fields ids in kmp_depend_info record.
-enum RTLDependInfoFieldsTy { BaseAddr, Len, Flags };
-} // namespace
 /// Translates internal dependency kind into the runtime kind.
 static RTLDependenceKindTy translateDependencyKind(OpenMPDependClauseKind K) {
   RTLDependenceKindTy DepKind;
   switch (K) {
   case OMPC_DEPEND_in:
-    DepKind = DepIn;
+    DepKind = RTLDependenceKindTy::DepIn;
   // Out and InOut dependencies must use the same code.
   case OMPC_DEPEND_out:
   case OMPC_DEPEND_inout:
-    DepKind = DepInOut;
+    DepKind = RTLDependenceKindTy::DepInOut;
   case OMPC_DEPEND_mutexinoutset:
-    DepKind = DepMutexInOutSet;
+    DepKind = RTLDependenceKindTy::DepMutexInOutSet;
   case OMPC_DEPEND_inoutset:
-    DepKind = DepInOutSet;
+    DepKind = RTLDependenceKindTy::DepInOutSet;
   case OMPC_DEPEND_outallmemory:
-    DepKind = DepOmpAllMem;
+    DepKind = RTLDependenceKindTy::DepOmpAllMem;
   case OMPC_DEPEND_source:
   case OMPC_DEPEND_sink:
@@ -4457,7 +4444,9 @@
       DepObjAddr, KmpDependInfoTy, Base.getBaseInfo(), Base.getTBAAInfo());
   // NumDeps = deps[i].base_addr;
   LValue BaseAddrLVal = CGF.EmitLValueForField(
-      NumDepsBase, *std::next(KmpDependInfoRD->field_begin(), BaseAddr));
+      NumDepsBase,
+      *std::next(KmpDependInfoRD->field_begin(),
+                 static_cast<unsigned int>(RTLDependInfoFields::BaseAddr)));
   llvm::Value *NumDeps = CGF.EmitLoadOfScalar(BaseAddrLVal, Loc);
   return std::make_pair(NumDeps, Base);
@@ -4503,18 +4492,24 @@
     // deps[i].base_addr = &<Dependencies[i].second>;
     LValue BaseAddrLVal = CGF.EmitLValueForField(
-        Base, *std::next(KmpDependInfoRD->field_begin(), BaseAddr));
+        Base,
+        *std::next(KmpDependInfoRD->field_begin(),
+                   static_cast<unsigned int>(RTLDependInfoFields::BaseAddr)));
     CGF.EmitStoreOfScalar(Addr, BaseAddrLVal);
     // deps[i].len = sizeof(<Dependencies[i].second>);
     LValue LenLVal = CGF.EmitLValueForField(
-        Base, *std::next(KmpDependInfoRD->field_begin(), Len));
+        Base, *std::next(KmpDependInfoRD->field_begin(),
+                         static_cast<unsigned int>(RTLDependInfoFields::Len)));
     CGF.EmitStoreOfScalar(Size, LenLVal);
     // deps[i].flags = <Dependencies[i].first>;
     RTLDependenceKindTy DepKind = translateDependencyKind(Data.DepKind);
     LValue FlagsLVal = CGF.EmitLValueForField(
-        Base, *std::next(KmpDependInfoRD->field_begin(), Flags));
-    CGF.EmitStoreOfScalar(llvm::ConstantInt::get(LLVMFlagsTy, DepKind),
-                          FlagsLVal);
+        Base,
+        *std::next(KmpDependInfoRD->field_begin(),
+                   static_cast<unsigned int>(RTLDependInfoFields::Flags)));
+    CGF.EmitStoreOfScalar(
+        llvm::ConstantInt::get(LLVMFlagsTy, static_cast<unsigned int>(DepKind)),
+        FlagsLVal);
     if (unsigned *P = Pos.dyn_cast<unsigned *>()) {
     } else {
@@ -4790,7 +4785,9 @@
   LValue Base = CGF.MakeAddrLValue(DependenciesArray, KmpDependInfoTy);
   // deps[i].base_addr = NumDependencies;
   LValue BaseAddrLVal = CGF.EmitLValueForField(
-      Base, *std::next(KmpDependInfoRD->field_begin(), BaseAddr));
+      Base,
+      *std::next(KmpDependInfoRD->field_begin(),
+                 static_cast<unsigned int>(RTLDependInfoFields::BaseAddr)));
   CGF.EmitStoreOfScalar(NumDepsVal, BaseAddrLVal);
   llvm::PointerUnion<unsigned *, LValue *> Pos;
   unsigned Idx = 1;
@@ -4870,9 +4867,11 @@
   // deps[i].flags = NewDepKind;
   RTLDependenceKindTy DepKind = translateDependencyKind(NewDepKind);
   LValue FlagsLVal = CGF.EmitLValueForField(
-      Base, *std::next(KmpDependInfoRD->field_begin(), Flags));
-  CGF.EmitStoreOfScalar(llvm::ConstantInt::get(LLVMFlagsTy, DepKind),
-                        FlagsLVal);
+      Base, *std::next(KmpDependInfoRD->field_begin(),
+                       static_cast<unsigned int>(RTLDependInfoFields::Flags)));
+  CGF.EmitStoreOfScalar(
+      llvm::ConstantInt::get(LLVMFlagsTy, static_cast<unsigned int>(DepKind)),
+      FlagsLVal);
   // Shift the address forward by one element.
   Address ElementNext =
cfe-commits mailing list

Reply via email to