https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/171767
>From 1c69d29651bb1b73c04cca422454eb7ffffd7c4c Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Thu, 11 Dec 2025 11:56:58 +0530 Subject: [PATCH 1/3] [OpenMP][MLIR] Add num_threads clause with dims modifier support --- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 50 +++++++++++- .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 2 + mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 79 +++++++++++++++++-- mlir/test/Dialect/OpenMP/invalid.mlir | 33 +++++++- mlir/test/Dialect/OpenMP/ops.mlir | 15 ++-- 5 files changed, 163 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index e36dc7c246f01..7525b6e4e99f6 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1069,16 +1069,60 @@ class OpenMP_NumThreadsClauseSkip< > : OpenMP_Clause<traits, arguments, assemblyFormat, description, extraClassDeclaration> { let arguments = (ins + ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims, + Variadic<AnyInteger>:$num_threads_values, Optional<IntLikeType>:$num_threads ); let optAssemblyFormat = [{ - `num_threads` `(` $num_threads `:` type($num_threads) `)` + `num_threads` `(` custom<NumThreadsClause>( + $num_threads_dims, $num_threads_values, type($num_threads_values), + $num_threads, type($num_threads) + ) `)` }]; let description = [{ - The optional `num_threads` parameter specifies the number of threads which - should be used to execute the parallel region. + num_threads clause specifies the desired number of threads in the team + space formed by the construct on which it appears. 
+ + With dims modifier: + - Uses `num_threads_dims` (dimension count) and `num_threads_values` (upper bounds list) + - Specifies upper bounds for each dimension (all must have same type) + - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)` + - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)` + + Without dims modifier: + - Uses `num_threads` + - Specifies the number of threads as a single integer value + - Format: `num_threads(value : type)` + - Example: `num_threads(%n : i32)` + }]; + + let extraClassDeclaration = [{ + /// Returns true if the dims modifier is explicitly present + bool hasDimsModifier() { + return getNumThreadsDims().has_value(); + } + + /// Returns the number of dimensions specified by dims modifier + unsigned getNumDimensions() { + if (!hasDimsModifier()) + return 1; + return static_cast<unsigned>(*getNumThreadsDims()); + } + + /// Returns all dimension values as an operand range + ::mlir::OperandRange getDimensionValues() { + return getNumThreadsValues(); + } + + /// Returns the value for a specific dimension index + /// Index must be less than getNumDimensions() + ::mlir::Value getDimensionValue(unsigned index) { + assert(index < getDimensionValues().size() && + "Dimension index out of bounds"); + return getDimensionValues()[index]; + } }]; } diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 6423d49859c97..0d5333ec2e455 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -448,6 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { /* allocate_vars = */ llvm::SmallVector<Value>{}, /* allocator_vars = */ llvm::SmallVector<Value>{}, /* if_expr = */ Value{}, + /* num_threads_dims = */ nullptr, + /* num_threads_values = */ llvm::SmallVector<Value>{}, /* num_threads = */ numThreadsVar, /* private_vars = */ ValueRange(), /* private_syms = */ nullptr, diff
--git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index d4dbf5f5244df..303ab94fbedff 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2533,6 +2533,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, ArrayRef<NamedAttribute> attributes) { ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(), /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr, + /*num_threads_dims=*/nullptr, + /*num_threads_values=*/ValueRange(), /*num_threads=*/nullptr, /*private_vars=*/ValueRange(), /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr, /*proc_bind_kind=*/nullptr, @@ -2544,13 +2546,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, void ParallelOp::build(OpBuilder &builder, OperationState &state, const ParallelOperands &clauses) { MLIRContext *ctx = builder.getContext(); - ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, - clauses.ifExpr, clauses.numThreads, clauses.privateVars, - makeArrayAttr(ctx, clauses.privateSyms), - clauses.privateNeedsBarrier, clauses.procBindKind, - clauses.reductionMod, clauses.reductionVars, - makeDenseBoolArrayAttr(ctx, clauses.reductionByref), - makeArrayAttr(ctx, clauses.reductionSyms)); + ParallelOp::build( + builder, state, clauses.allocateVars, clauses.allocatorVars, + clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues, + clauses.numThreads, clauses.privateVars, + makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, + clauses.procBindKind, clauses.reductionMod, clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionByref), + makeArrayAttr(ctx, clauses.reductionSyms)); } template <typename OpType> @@ -2597,13 +2600,40 @@ static LogicalResult verifyPrivateVarList(OpType &op) { } LogicalResult ParallelOp::verify() { + // verify num_threads clause restrictions + auto numThreadsDims = getNumThreadsDims(); + 
auto numThreadsValues = getNumThreadsValues(); + auto numThreads = getNumThreads(); + + // num_threads with dims modifier + if (numThreadsDims.has_value() && numThreadsValues.empty()) { + return emitError( + "num_threads dims modifier requires values to be specified"); + } + + if (numThreadsDims.has_value() && + numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) { + return emitError("num_threads dims(") + << *numThreadsDims << ") specified but " << numThreadsValues.size() + << " values provided"; + } + + // num_threads dims and number of threads cannot be used together + if (numThreadsDims.has_value() && numThreads) { + return emitError( + "num_threads dims and number of threads cannot be used together"); + } + + // verify allocate clause restrictions if (getAllocateVars().size() != getAllocatorVars().size()) return emitError( "expected equal sizes for allocate and allocator variables"); + // verify private variables restrictions if (failed(verifyPrivateVarList(*this))) return failure(); + // verify reduction variables restrictions return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(), getReductionByref()); } @@ -4647,6 +4677,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, } } +//===----------------------------------------------------------------------===// +// Parser and printer for num_threads clause +//===----------------------------------------------------------------------===// +static ParseResult +parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr, + SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, + SmallVectorImpl<Type> &types, + std::optional<OpAsmParser::UnresolvedOperand> &bounds, + Type &boundsType) { + if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) { + return success(); + } + + OpAsmParser::UnresolvedOperand boundsOperand; + if (parser.parseOperand(boundsOperand) || parser.parseColon() || + parser.parseType(boundsType)) { + return failure(); + 
} + bounds = boundsOperand; + return success(); +} + +static void printNumThreadsClause(OpAsmPrinter &p, Operation *op, + IntegerAttr dimsAttr, OperandRange values, + TypeRange types, Value bounds, + Type boundsType) { + if (!values.empty()) { + printDimsModifierWithValues(p, dimsAttr, values, types); + } + if (bounds) { + p.printOperand(bounds); + p << " : " << boundsType; + } +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index dd367aba8da27..9e2e5722aab9f 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) { // ----- +func.func @num_threads_dims_no_values() { + // expected-error@+1 {{num_threads dims modifier requires values to be specified}} + "omp.parallel"() ({ + omp.terminator + }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : i64} : () -> () + return +} + +// ----- + +func.func @num_threads_dims_mismatch(%n : i64) { + // expected-error@+1 {{num_threads dims(2) specified but 1 values provided}} + omp.parallel num_threads(dims(2): %n : i64) { + omp.terminator + } + + return +} + +// ----- + +func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) { + // expected-error@+1 {{num_threads dims and number of threads cannot be used together}} + "omp.parallel"(%n, %n, %m) ({ + omp.terminator + }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : i64} : (i64, i64, i64) -> () + return +} + +// ----- + func.func @nowait_not_allowed(%n : memref<i32>) { // expected-error@+1 {{expected '{' to begin a region}} omp.parallel nowait {} @@ -2766,7 +2797,7 @@ func.func @undefined_privatizer(%arg0: index) { // ----- func.func @undefined_privatizer(%arg0: !llvm.ptr) { // expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. 
privatizer op symbols: 2}} - "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({ + "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({ ^bb0(%arg2: !llvm.ptr): omp.terminator }) : (!llvm.ptr) -> () diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 3633a4be1eb62..585c9483c08a9 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32) "omp.parallel"(%data_var, %data_var, %num_threads) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> () + }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> () // CHECK: omp.barrier omp.barrier @@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}}) "omp.parallel"(%data_var, %data_var, %if_cond) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> () + }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> () // test without allocate // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) "omp.parallel"(%if_cond, %num_threads) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> () + }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> () omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> () + }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, 
proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> () // test with multiple parameters for single variadic argument // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) "omp.parallel" (%data_var, %data_var) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> () + }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> () // CHECK: omp.parallel omp.parallel { @@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre omp.terminator } + // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64) + omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) { + omp.terminator + } + // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) { omp.terminator >From 6946aff41bb7f744d6445d0fc227fb7807ea2191 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Thu, 11 Dec 2025 12:11:49 +0530 Subject: [PATCH 2/3] Mark mlir->llvmir translation for num_threads with dims as NYI --- .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 00f782e87d5af..8d3d0ccb665bd 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2879,6 +2879,10 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, if (auto ifVar = opInst.getIfExpr()) ifCond = moduleTranslation.lookupValue(ifVar); llvm::Value *numThreads = nullptr; + // num_threads dims and values are not yet supported + assert(!opInst.getNumThreadsDims().has_value() && + opInst.getNumThreadsValues().empty() && + 
"Lowering of num_threads with dims modifier is NYI."); if (auto numThreadsVar = opInst.getNumThreads()) numThreads = moduleTranslation.lookupValue(numThreadsVar); auto pbKind = llvm::omp::OMP_PROC_BIND_default; @@ -5604,6 +5608,10 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, llvm_unreachable("unsupported host_eval use"); }) .Case([&](omp::ParallelOp parallelOp) { + // num_threads dims and values are not yet supported + assert(!parallelOp.getNumThreadsDims().has_value() && + parallelOp.getNumThreadsValues().empty() && + "Lowering of num_threads with dims modifier is NYI."); if (parallelOp.getNumThreads() == blockArg) numThreads = hostEvalVar; else @@ -5724,8 +5732,13 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, threadLimit = teamsOp.getThreadLimit(); } - if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) + if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) { + // num_threads dims and values are not yet supported + assert(!parallelOp.getNumThreadsDims().has_value() && + parallelOp.getNumThreadsValues().empty() && + "Lowering of num_threads with dims modifier is NYI."); numThreads = parallelOp.getNumThreads(); + } } // Handle clauses impacting the number of teams. 
>From 33dcfd92bea8181da414b766101847338ee3b963 Mon Sep 17 00:00:00 2001 From: skc7 <[email protected]> Date: Thu, 11 Dec 2025 17:37:52 +0530 Subject: [PATCH 3/3] few more fixes --- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 33 ++++++-------- .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 4 +- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 44 +++++++++---------- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 9 ++-- mlir/test/Dialect/OpenMP/invalid.mlir | 10 ++--- 5 files changed, 45 insertions(+), 55 deletions(-) diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 7525b6e4e99f6..09c1d4a8a5866 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1069,14 +1069,14 @@ class OpenMP_NumThreadsClauseSkip< > : OpenMP_Clause<traits, arguments, assemblyFormat, description, extraClassDeclaration> { let arguments = (ins - ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims, - Variadic<AnyInteger>:$num_threads_values, + ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims, + Variadic<AnyInteger>:$num_threads_dims_values, Optional<IntLikeType>:$num_threads ); let optAssemblyFormat = [{ `num_threads` `(` custom<NumThreadsClause>( - $num_threads_dims, $num_threads_values, type($num_threads_values), + $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values), $num_threads, type($num_threads) ) `)` }]; @@ -1086,7 +1086,7 @@ class OpenMP_NumThreadsClauseSkip< space formed by the construct on which it appears. 
With dims modifier: - - Uses `num_threads_dims` (dimension count) and `num_threads_values` (upper bounds list) + - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list) - Specifies upper bounds for each dimension (all must have same type) - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)` - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)` @@ -1100,28 +1100,23 @@ class OpenMP_NumThreadsClauseSkip< let extraClassDeclaration = [{ /// Returns true if the dims modifier is explicitly present - bool hasDimsModifier() { - return getNumThreadsDims().has_value(); + bool hasNumThreadsDimsModifier() { + return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value(); } /// Returns the number of dimensions specified by dims modifier - unsigned getNumDimensions() { - if (!hasDimsModifier()) + unsigned getNumThreadsDimsCount() { + if (!hasNumThreadsDimsModifier()) return 1; - return static_cast<unsigned>(*getNumThreadsDims()); - } - - /// Returns all dimension values as an operand range - ::mlir::OperandRange getDimensionValues() { - return getNumThreadsValues(); + return static_cast<unsigned>(*getNumThreadsNumDims()); } /// Returns the value for a specific dimension index - /// Index must be less than getNumDimensions() - ::mlir::Value getDimensionValue(unsigned index) { - assert(index < getDimensionValues().size() && - "Dimension index out of bounds"); - return getDimensionValues()[index]; + /// Index must be less than getNumThreadsDimsCount() + ::mlir::Value getNumThreadsDimsValue(unsigned index) { + assert(index < getNumThreadsDimsCount() && + "Num threads dims index out of bounds"); + return getNumThreadsDimsValues()[index]; } }]; } diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 0d5333ec2e455..ab7bded7835be 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp 
@@ -448,8 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { /* allocate_vars = */ llvm::SmallVector<Value>{}, /* allocator_vars = */ llvm::SmallVector<Value>{}, /* if_expr = */ Value{}, - /* num_threads_dims = */ nullptr, - /* num_threads_values = */ llvm::SmallVector<Value>{}, + /* num_threads_num_dims = */ nullptr, + /* num_threads_dims_values = */ llvm::SmallVector<Value>{}, /* num_threads = */ numThreadsVar, /* private_vars = */ ValueRange(), /* private_syms = */ nullptr, diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 303ab94fbedff..a9ed0274cd21c 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2548,7 +2548,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, MLIRContext *ctx = builder.getContext(); ParallelOp::build( builder, state, clauses.allocateVars, clauses.allocatorVars, - clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues, + clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues, clauses.numThreads, clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, clauses.procBindKind, clauses.reductionMod, clauses.reductionVars, @@ -2599,30 +2599,28 @@ static LogicalResult verifyPrivateVarList(OpType &op) { return success(); } -LogicalResult ParallelOp::verify() { - // verify num_threads clause restrictions - auto numThreadsDims = getNumThreadsDims(); - auto numThreadsValues = getNumThreadsValues(); - auto numThreads = getNumThreads(); - - // num_threads with dims modifier - if (numThreadsDims.has_value() && numThreadsValues.empty()) { - return emitError( - "num_threads dims modifier requires values to be specified"); - } - - if (numThreadsDims.has_value() && - numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) { - return emitError("num_threads dims(") - << *numThreadsDims << ") specified but " << 
numThreadsValues.size() - << " values provided"; +// Helper: Verify num_threads clause +LogicalResult +verifyNumThreadsClause(Operation *op, + std::optional<IntegerAttr> numThreadsNumDims, + OperandRange numThreadsDimsValues, Value numThreads) { + bool hasDimsModifier = + numThreadsNumDims.has_value() && numThreadsNumDims.value(); + if (hasDimsModifier && numThreads) { + return op->emitError("num_threads with dims modifier cannot be used " + "together with number of threads"); } + if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues))) + return failure(); + return success(); +} - // num_threads dims and number of threads cannot be used together - if (numThreadsDims.has_value() && numThreads) { - return emitError( - "num_threads dims and number of threads cannot be used together"); - } +LogicalResult ParallelOp::verify() { + // verify num_threads clause restrictions + if (failed(verifyNumThreadsClause( + getOperation(), this->getNumThreadsNumDimsAttr(), + this->getNumThreadsDimsValues(), this->getNumThreads()))) + return failure(); // verify allocate clause restrictions if (getAllocateVars().size() != getAllocatorVars().size()) diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 8d3d0ccb665bd..2bfb9fb2211c4 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2880,8 +2880,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, ifCond = moduleTranslation.lookupValue(ifVar); llvm::Value *numThreads = nullptr; // num_threads dims and values are not yet supported - assert(!opInst.getNumThreadsDims().has_value() && - opInst.getNumThreadsValues().empty() && + assert(!opInst.hasNumThreadsDimsModifier() && "Lowering of num_threads with dims modifier is NYI."); if (auto numThreadsVar = opInst.getNumThreads()) numThreads = 
moduleTranslation.lookupValue(numThreadsVar); @@ -5609,8 +5608,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, }) .Case([&](omp::ParallelOp parallelOp) { // num_threads dims and values are not yet supported - assert(!parallelOp.getNumThreadsDims().has_value() && - parallelOp.getNumThreadsValues().empty() && + assert(!parallelOp.hasNumThreadsDimsModifier() && "Lowering of num_threads with dims modifier is NYI."); if (parallelOp.getNumThreads() == blockArg) numThreads = hostEvalVar; @@ -5734,8 +5732,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) { // num_threads dims and values are not yet supported - assert(!parallelOp.getNumThreadsDims().has_value() && - parallelOp.getNumThreadsValues().empty() && + assert(!parallelOp.hasNumThreadsDimsModifier() && "Lowering of num_threads with dims modifier is NYI."); numThreads = parallelOp.getNumThreads(); } diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 9e2e5722aab9f..db0ddcb415d42 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -31,17 +31,17 @@ func.func @num_threads_once(%n : si32) { // ----- func.func @num_threads_dims_no_values() { - // expected-error@+1 {{num_threads dims modifier requires values to be specified}} + // expected-error@+1 {{dims modifier requires values to be specified}} "omp.parallel"() ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : i64} : () -> () + }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> () return } // ----- func.func @num_threads_dims_mismatch(%n : i64) { - // expected-error@+1 {{num_threads dims(2) specified but 1 values provided}} + // expected-error@+1 {{dims(2) specified but 1 values provided}} omp.parallel num_threads(dims(2): %n : i64) { omp.terminator } @@ -52,10 +52,10 @@ 
func.func @num_threads_dims_mismatch(%n : i64) { // ----- func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) { - // expected-error@+1 {{num_threads dims and number of threads cannot be used together}} + // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}} "omp.parallel"(%n, %n, %m) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : i64} : (i64, i64, i64) -> () + }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> () return } _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
