llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Chaitanya (skc7)

<details>
<summary>Changes</summary>

This PR adds support for the OpenMP 6.1 `num_threads` clause with the `dims` modifier.
LLVM IR translation of `num_threads` with the `dims` modifier is marked as not yet implemented (NYI).

---
Full diff: https://github.com/llvm/llvm-project/pull/171767.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+42-3) 
- (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+2) 
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+70-7) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+11-1) 
- (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+32-1) 
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+10-5) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td 
b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index e36dc7c246f01..09c1d4a8a5866 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,16 +1069,55 @@ class OpenMP_NumThreadsClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let arguments = (ins
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
+    Variadic<AnyInteger>:$num_threads_dims_values,
     Optional<IntLikeType>:$num_threads
   );
 
   let optAssemblyFormat = [{
-    `num_threads` `(` $num_threads `:` type($num_threads) `)`
+    `num_threads` `(` custom<NumThreadsClause>(
+      $num_threads_num_dims, $num_threads_dims_values, 
type($num_threads_dims_values),
+      $num_threads, type($num_threads)
+    ) `)`
   }];
 
   let description = [{
-    The optional `num_threads` parameter specifies the number of threads which
-    should be used to execute the parallel region.
+    The `num_threads` clause specifies the desired number of threads in the
+    team space formed by the construct on which it appears.
+
+    With dims modifier:
+    - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
+    - Specifies upper bounds for each dimension (all must have the same type)
+    - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+    - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+    Without dims modifier:
+    - Uses `num_threads` (a single integer operand)
+    - The operand is the requested number of threads for the region
+    - Format: `num_threads(value : type)`
+    - Example: `num_threads(%n : i32)`
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns true if the dims modifier is explicitly present
+    bool hasNumThreadsDimsModifier() {
+      return getNumThreadsNumDims().has_value() && 
getNumThreadsNumDims().value();
+    }
+
+    /// Returns the number of dimensions specified by dims modifier
+    unsigned getNumThreadsDimsCount() {
+      if (!hasNumThreadsDimsModifier())
+        return 1;
+      return static_cast<unsigned>(*getNumThreadsNumDims());
+    }
+
+    /// Returns the upper bound for dimension `index` of the dims modifier.
+    /// Requires the dims modifier and index < getNumThreadsDimsCount().
+    ::mlir::Value getNumThreadsDimsValue(unsigned index) {
+      assert(index < getNumThreadsDimsCount() &&
+             "Num threads dims index out of bounds");
+      return getNumThreadsDimsValues()[index];
+    }
   }];
 }
 
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp 
b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..ab7bded7835be 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,6 +448,8 @@ struct ParallelOpLowering : public 
OpRewritePattern<scf::ParallelOp> {
         /* allocate_vars = */ llvm::SmallVector<Value>{},
         /* allocator_vars = */ llvm::SmallVector<Value>{},
         /* if_expr = */ Value{},
+        /* num_threads_num_dims = */ nullptr,
+        /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
         /* num_threads = */ numThreadsVar,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp 
b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d4dbf5f5244df..a9ed0274cd21c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2533,6 +2533,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState 
&state,
                        ArrayRef<NamedAttribute> attributes) {
   ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
                     /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
+                    /*num_threads_num_dims=*/nullptr,
+                    /*num_threads_dims_values=*/ValueRange(),
                     /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
                     /*private_syms=*/nullptr, 
/*private_needs_barrier=*/nullptr,
                     /*proc_bind_kind=*/nullptr,
@@ -2544,13 +2546,14 @@ void ParallelOp::build(OpBuilder &builder, 
OperationState &state,
 void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        const ParallelOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  ParallelOp::build(builder, state, clauses.allocateVars, 
clauses.allocatorVars,
-                    clauses.ifExpr, clauses.numThreads, clauses.privateVars,
-                    makeArrayAttr(ctx, clauses.privateSyms),
-                    clauses.privateNeedsBarrier, clauses.procBindKind,
-                    clauses.reductionMod, clauses.reductionVars,
-                    makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-                    makeArrayAttr(ctx, clauses.reductionSyms));
+  ParallelOp::build(
+      builder, state, clauses.allocateVars, clauses.allocatorVars,
+      clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
+      clauses.numThreads, clauses.privateVars,
+      makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+      clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+      makeArrayAttr(ctx, clauses.reductionSyms));
 }
 
 template <typename OpType>
@@ -2596,14 +2599,39 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
   return success();
 }
 
+// Helper: Verify num_threads clause
+LogicalResult
+verifyNumThreadsClause(Operation *op,
+                       std::optional<IntegerAttr> numThreadsNumDims,
+                       OperandRange numThreadsDimsValues, Value numThreads) {
+  bool hasDimsModifier =
+      numThreadsNumDims.has_value() && numThreadsNumDims.value();
+  if (hasDimsModifier && numThreads) {
+    return op->emitError("num_threads with dims modifier cannot be used "
+                         "together with number of threads");
+  }
+  if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
+    return failure();
+  return success();
+}
+
 LogicalResult ParallelOp::verify() {
+  // verify num_threads clause restrictions
+  if (failed(verifyNumThreadsClause(
+          getOperation(), this->getNumThreadsNumDimsAttr(),
+          this->getNumThreadsDimsValues(), this->getNumThreads())))
+    return failure();
+
+  // verify allocate clause restrictions
   if (getAllocateVars().size() != getAllocatorVars().size())
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
+  // verify private variables restrictions
   if (failed(verifyPrivateVarList(*this)))
     return failure();
 
+  // verify reduction variables restrictions
   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                 getReductionByref());
 }
@@ -4647,6 +4675,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, 
Operation *op,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_threads clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+                      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                      SmallVectorImpl<Type> &types,
+                      std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+                      Type &boundsType) {
+  if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) 
{
+    return success();
+  }
+
+  OpAsmParser::UnresolvedOperand boundsOperand;
+  if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+      parser.parseType(boundsType)) {
+    return failure();
+  }
+  bounds = boundsOperand;
+  return success();
+}
+
+static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
+                                  IntegerAttr dimsAttr, OperandRange values,
+                                  TypeRange types, Value bounds,
+                                  Type boundsType) {
+  if (!values.empty()) {
+    printDimsModifierWithValues(p, dimsAttr, values, types);
+  }
+  if (bounds) {
+    p.printOperand(bounds);
+    p << " : " << boundsType;
+  }
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git 
a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp 
b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 00f782e87d5af..2bfb9fb2211c4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2879,6 +2879,9 @@ convertOmpParallel(omp::ParallelOp opInst, 
llvm::IRBuilderBase &builder,
   if (auto ifVar = opInst.getIfExpr())
     ifCond = moduleTranslation.lookupValue(ifVar);
   llvm::Value *numThreads = nullptr;
+  // num_threads dims and values are not yet supported
+  assert(!opInst.hasNumThreadsDimsModifier() &&
+         "Lowering of num_threads with dims modifier is NYI.");
   if (auto numThreadsVar = opInst.getNumThreads())
     numThreads = moduleTranslation.lookupValue(numThreadsVar);
   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
@@ -5604,6 +5607,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value 
&numThreads,
               llvm_unreachable("unsupported host_eval use");
           })
           .Case([&](omp::ParallelOp parallelOp) {
+            // num_threads dims and values are not yet supported
+            assert(!parallelOp.hasNumThreadsDimsModifier() &&
+                   "Lowering of num_threads with dims modifier is NYI.");
             if (parallelOp.getNumThreads() == blockArg)
               numThreads = hostEvalVar;
             else
@@ -5724,8 +5730,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation 
*capturedOp,
       threadLimit = teamsOp.getThreadLimit();
     }
 
-    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+      // num_threads dims and values are not yet supported
+      assert(!parallelOp.hasNumThreadsDimsModifier() &&
+             "Lowering of num_threads with dims modifier is NYI.");
       numThreads = parallelOp.getNumThreads();
+    }
   }
 
   // Handle clauses impacting the number of teams.
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir 
b/mlir/test/Dialect/OpenMP/invalid.mlir
index dd367aba8da27..db0ddcb415d42 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) {
 
 // -----
 
+func.func @num_threads_dims_no_values() {
+  // expected-error@+1 {{dims modifier requires values to be specified}}
+  "omp.parallel"() ({
+    omp.terminator
+  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 
2 : i64} : () -> ()
+  return
+}
+
+// -----
+
+func.func @num_threads_dims_mismatch(%n : i64) {
+  // expected-error@+1 {{dims(2) specified but 1 values provided}}
+  omp.parallel num_threads(dims(2): %n : i64) {
+    omp.terminator
+  }
+
+  return
+}
+
+// -----
+
+func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
+  // expected-error@+1 {{num_threads with dims modifier cannot be used 
together with number of threads}}
+  "omp.parallel"(%n, %n, %m) ({
+    omp.terminator
+  }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 
2 : i64} : (i64, i64, i64) -> ()
+  return
+}
+
+// -----
+
 func.func @nowait_not_allowed(%n : memref<i32>) {
   // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel nowait {}
@@ -2766,7 +2797,7 @@ func.func @undefined_privatizer(%arg0: index) {
 // -----
 func.func @undefined_privatizer(%arg0: !llvm.ptr) {
   // expected-error @below {{inconsistent number of private variables and 
privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
-  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, 
private_syms = [@x.privatizer, @y.privatizer]}> ({
+  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 
0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
     ^bb0(%arg2: !llvm.ptr):
       omp.terminator
     }) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir 
b/mlir/test/Dialect/OpenMP/ops.mlir
index 3633a4be1eb62..585c9483c08a9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : 
i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : 
memref<i32>) num_threads(%{{.*}} : i32)
     "omp.parallel"(%data_var, %data_var, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, 
memref<i32>, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, 
memref<i32>, i32) -> ()
 
   // CHECK: omp.barrier
     omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : 
i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : 
memref<i32>) if(%{{.*}})
     "omp.parallel"(%data_var, %data_var, %if_cond) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, 
memref<i32>, i1) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, 
memref<i32>, i1) -> ()
 
   // test without allocate
   // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
     "omp.parallel"(%if_cond, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
 
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = 
#omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = 
#omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
 
   // test with multiple parameters for single variadic argument
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : 
memref<i32>)
   "omp.parallel" (%data_var, %data_var) ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, 
memref<i32>) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, 
memref<i32>) -> ()
 
   // CHECK: omp.parallel
   omp.parallel {
@@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, 
%if_cond : i1, %num_thre
    omp.terminator
  }
 
+ // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
+ omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+   omp.terminator
+ }
+
  // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : 
memref<i32>)
  omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
    omp.terminator

``````````

</details>


https://github.com/llvm/llvm-project/pull/171767
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to