llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-llvm Author: Chaitanya (skc7) <details> <summary>Changes</summary> This PR adds support for the OpenMP 6.1 `dims` modifier on the `num_threads` clause. LLVM IR translation of `num_threads` with the `dims` modifier is marked as NYI (not yet implemented). --- Full diff: https://github.com/llvm/llvm-project/pull/171767.diff 6 Files Affected: - (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+42-3) - (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+2) - (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+70-7) - (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+11-1) - (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+32-1) - (modified) mlir/test/Dialect/OpenMP/ops.mlir (+10-5) ``````````diff diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index e36dc7c246f01..09c1d4a8a5866 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1069,16 +1069,55 @@ class OpenMP_NumThreadsClauseSkip< > : OpenMP_Clause<traits, arguments, assemblyFormat, description, extraClassDeclaration> { let arguments = (ins + ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims, + Variadic<AnyInteger>:$num_threads_dims_values, Optional<IntLikeType>:$num_threads ); let optAssemblyFormat = [{ - `num_threads` `(` $num_threads `:` type($num_threads) `)` + `num_threads` `(` custom<NumThreadsClause>( + $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values), + $num_threads, type($num_threads) + ) `)` }]; let description = [{ - The optional `num_threads` parameter specifies the number of threads which - should be used to execute the parallel region. + The `num_threads` clause specifies the desired number of threads in the team + space formed by the construct on which it appears. 
+ + The clause takes one of two mutually exclusive forms: + + With dims modifier: + - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list) + - Specifies an upper bound for each dimension (all values must have the same type) + - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)` + - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)` + + Without dims modifier: + - Uses `num_threads` + - Specifies a single integer value for the desired number of threads + - Format: `num_threads(value : type)` + - Example: `num_threads(%n : i32)` + }]; + + let extraClassDeclaration = [{ + /// Returns true if the dims modifier is explicitly present + bool hasNumThreadsDimsModifier() { + return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value(); + } + + /// Returns the number of dimensions specified by the dims modifier, + /// or 1 if the dims modifier is absent + unsigned getNumThreadsDimsCount() { + if (!hasNumThreadsDimsModifier()) + return 1; + return static_cast<unsigned>(*getNumThreadsNumDims()); + } + + /// Returns the value for a specific dimension index + /// Index must be less than getNumThreadsDimsCount(). NOTE(review): without + /// the dims modifier getNumThreadsDimsCount() returns 1, so only call this + /// when hasNumThreadsDimsModifier() is true. + ::mlir::Value getNumThreadsDimsValue(unsigned index) { + assert(index < getNumThreadsDimsCount() && + "Num threads dims index out of bounds"); + return getNumThreadsDimsValues()[index]; + } }]; } diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 6423d49859c97..ab7bded7835be 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -448,6 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { /* allocate_vars = */ llvm::SmallVector<Value>{}, /* allocator_vars = */ llvm::SmallVector<Value>{}, /* if_expr = */ Value{}, + /* num_threads_num_dims = */ nullptr, + /* num_threads_dims_values = */ llvm::SmallVector<Value>{}, /* num_threads = */ numThreadsVar, /* private_vars = */ ValueRange(), /* private_syms = */ nullptr, diff --git 
a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index d4dbf5f5244df..a9ed0274cd21c 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -2533,6 +2533,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, ArrayRef<NamedAttribute> attributes) { ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(), /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr, + /*num_threads_num_dims=*/nullptr, + /*num_threads_dims_values=*/ValueRange(), /*num_threads=*/nullptr, /*private_vars=*/ValueRange(), /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr, /*proc_bind_kind=*/nullptr, @@ -2544,13 +2546,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, void ParallelOp::build(OpBuilder &builder, OperationState &state, const ParallelOperands &clauses) { MLIRContext *ctx = builder.getContext(); - ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, - clauses.ifExpr, clauses.numThreads, clauses.privateVars, - makeArrayAttr(ctx, clauses.privateSyms), - clauses.privateNeedsBarrier, clauses.procBindKind, - clauses.reductionMod, clauses.reductionVars, - makeDenseBoolArrayAttr(ctx, clauses.reductionByref), - makeArrayAttr(ctx, clauses.reductionSyms)); + ParallelOp::build( + builder, state, clauses.allocateVars, clauses.allocatorVars, + clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues, + clauses.numThreads, clauses.privateVars, + makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, + clauses.procBindKind, clauses.reductionMod, clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionByref), + makeArrayAttr(ctx, clauses.reductionSyms)); } template <typename OpType> @@ -2596,14 +2599,39 @@ static LogicalResult verifyPrivateVarList(OpType &op) { return success(); } +// Verifies the num_threads clause: the dims modifier and a plain num_threads +// value are mutually exclusive, and dims/values consistency is checked by +// verifyDimsModifier. +LogicalResult +verifyNumThreadsClause(Operation *op, 
std::optional<IntegerAttr> numThreadsNumDims, + OperandRange numThreadsDimsValues, Value numThreads) { + bool hasDimsModifier = + numThreadsNumDims.has_value() && numThreadsNumDims.value(); + if (hasDimsModifier && numThreads) { + return op->emitError("num_threads with dims modifier cannot be used " + "together with number of threads"); + } + if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues))) + return failure(); + return success(); +} + LogicalResult ParallelOp::verify() { + // verify num_threads clause restrictions + if (failed(verifyNumThreadsClause( + getOperation(), this->getNumThreadsNumDimsAttr(), + this->getNumThreadsDimsValues(), this->getNumThreads()))) + return failure(); + + // verify allocate clause restrictions if (getAllocateVars().size() != getAllocatorVars().size()) return emitError( "expected equal sizes for allocate and allocator variables"); + // verify private variables restrictions if (failed(verifyPrivateVarList(*this))) return failure(); + // verify reduction variables restrictions return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(), getReductionByref()); } @@ -4647,6 +4675,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, } } +//===----------------------------------------------------------------------===// +// Parser and printer for num_threads clause +//===----------------------------------------------------------------------===// +static ParseResult +parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr, + SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, + SmallVectorImpl<Type> &types, + std::optional<OpAsmParser::UnresolvedOperand> &bounds, + Type &boundsType) { + if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) { + return success(); + } + + OpAsmParser::UnresolvedOperand boundsOperand; + if (parser.parseOperand(boundsOperand) || parser.parseColon() || + parser.parseType(boundsType)) { + return failure(); + } + bounds = 
boundsOperand; + return success(); +} + +static void printNumThreadsClause(OpAsmPrinter &p, Operation *op, + IntegerAttr dimsAttr, OperandRange values, + TypeRange types, Value bounds, + Type boundsType) { + if (!values.empty()) { + printDimsModifierWithValues(p, dimsAttr, values, types); + } + if (bounds) { + p.printOperand(bounds); + p << " : " << boundsType; + } +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 00f782e87d5af..2bfb9fb2211c4 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2879,6 +2879,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, if (auto ifVar = opInst.getIfExpr()) ifCond = moduleTranslation.lookupValue(ifVar); llvm::Value *numThreads = nullptr; + // num_threads dims and values are not yet supported + assert(!opInst.hasNumThreadsDimsModifier() && + "Lowering of num_threads with dims modifier is NYI."); if (auto numThreadsVar = opInst.getNumThreads()) numThreads = moduleTranslation.lookupValue(numThreadsVar); auto pbKind = llvm::omp::OMP_PROC_BIND_default; @@ -5604,6 +5607,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, llvm_unreachable("unsupported host_eval use"); }) .Case([&](omp::ParallelOp parallelOp) { + // num_threads dims and values are not yet supported + assert(!parallelOp.hasNumThreadsDimsModifier() && + "Lowering of num_threads with dims modifier is NYI."); if (parallelOp.getNumThreads() == blockArg) numThreads = hostEvalVar; else @@ -5724,8 +5730,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, threadLimit = teamsOp.getThreadLimit(); } - if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) + if (auto parallelOp = 
castOrGetParentOfType<omp::ParallelOp>(capturedOp)) { + // num_threads dims and values are not yet supported + assert(!parallelOp.hasNumThreadsDimsModifier() && + "Lowering of num_threads with dims modifier is NYI."); numThreads = parallelOp.getNumThreads(); + } } // Handle clauses impacting the number of teams. diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index dd367aba8da27..db0ddcb415d42 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) { // ----- +func.func @num_threads_dims_no_values() { + // expected-error@+1 {{dims modifier requires values to be specified}} + "omp.parallel"() ({ + omp.terminator + }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> () + return +} + +// ----- + +func.func @num_threads_dims_mismatch(%n : i64) { + // expected-error@+1 {{dims(2) specified but 1 values provided}} + omp.parallel num_threads(dims(2): %n : i64) { + omp.terminator + } + + return +} + +// ----- + +func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) { + // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}} + "omp.parallel"(%n, %n, %m) ({ + omp.terminator + }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> () + return +} + +// ----- + func.func @nowait_not_allowed(%n : memref<i32>) { // expected-error@+1 {{expected '{' to begin a region}} omp.parallel nowait {} @@ -2766,7 +2797,7 @@ func.func @undefined_privatizer(%arg0: index) { // ----- func.func @undefined_privatizer(%arg0: !llvm.ptr) { // expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. 
privatizer op symbols: 2}} - "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({ + "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({ ^bb0(%arg2: !llvm.ptr): omp.terminator }) : (!llvm.ptr) -> () diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 3633a4be1eb62..585c9483c08a9 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32) "omp.parallel"(%data_var, %data_var, %num_threads) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> () + }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> () // CHECK: omp.barrier omp.barrier @@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}}) "omp.parallel"(%data_var, %data_var, %if_cond) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> () + }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> () // test without allocate // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) "omp.parallel"(%if_cond, %num_threads) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> () + }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> () omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> () + }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, 
proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> () // test with multiple parameters for single variadic argument // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) "omp.parallel" (%data_var, %data_var) ({ omp.terminator - }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> () + }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> () // CHECK: omp.parallel omp.parallel { @@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre omp.terminator } + // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64) + omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) { + omp.terminator + } + // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) { omp.terminator `````````` </details> https://github.com/llvm/llvm-project/pull/171767 _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
