================ @@ -3889,6 +3889,215 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg, return builder.saveIP(); } +/// Follow uses of `host_eval`-defined block arguments of the given `omp.target` +/// operation and populate output variables with their corresponding host value +/// (i.e. operand evaluated outside of the target region), based on their uses +/// inside of the target region. +/// +/// Loop bounds and steps are only optionally populated, if output vectors are +/// provided. +static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads, + Value &numTeamsLower, Value &numTeamsUpper, + Value &threadLimit) { + auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp); + for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(), + blockArgIface.getHostEvalBlockArgs())) { + Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item); + + for (Operation *user : blockArg.getUsers()) { + llvm::TypeSwitch<Operation *>(user) + .Case([&](omp::TeamsOp teamsOp) { + if (teamsOp.getNumTeamsLower() == blockArg) + numTeamsLower = hostEvalVar; + else if (teamsOp.getNumTeamsUpper() == blockArg) + numTeamsUpper = hostEvalVar; + else if (teamsOp.getThreadLimit() == blockArg) + threadLimit = hostEvalVar; + else + llvm_unreachable("unsupported host_eval use"); + }) + .Case([&](omp::ParallelOp parallelOp) { + if (parallelOp.getNumThreads() == blockArg) + numThreads = hostEvalVar; + else + llvm_unreachable("unsupported host_eval use"); + }) + .Case([&](omp::LoopNestOp loopOp) { + // TODO: Extract bounds and step values. + }) + .Default([](Operation *) { + llvm_unreachable("unsupported host_eval use"); + }); + } + } +} + +/// If \p op is of the given type parameter, return it casted to that type. +/// Otherwise, if its immediate parent operation (or some other higher-level +/// parent, if \p immediateParent is false) is of that type, return that parent +/// casted to the given type. +/// +/// If \p op is \c null or neither it or its parent(s) are of the specified +/// type, return a \c null operation. +template <typename OpTy> +static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) { + if (!op) + return OpTy(); + + if (OpTy casted = dyn_cast<OpTy>(op)) + return casted; + + if (immediateParent) + return dyn_cast_if_present<OpTy>(op->getParentOp()); + + return op->getParentOfType<OpTy>(); +} + +/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default +/// values as stated by the corresponding clauses, if constant. +/// +/// These default values must be set before the creation of the outlined LLVM +/// function for the target region, so that they can be used to initialize the +/// corresponding global `ConfigurationEnvironmentTy` structure. +static void +initTargetDefaultAttrs(omp::TargetOp targetOp, + llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs, + bool isTargetDevice) { + // TODO: Handle constant 'if' clauses. + Operation *capturedOp = targetOp.getInnermostCapturedOmpOp(); + + Value numThreads, numTeamsLower, numTeamsUpper, threadLimit; + if (!isTargetDevice) { + extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper, + threadLimit); + } else { + // In the target device, values for these clauses are not passed as + // host_eval, but instead evaluated prior to entry to the region. This + // ensures values are mapped and available inside of the target region. + if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) { + numTeamsLower = teamsOp.getNumTeamsLower(); + numTeamsUpper = teamsOp.getNumTeamsUpper(); + threadLimit = teamsOp.getThreadLimit(); + } + + if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) + numThreads = parallelOp.getNumThreads(); + } + + auto extractConstInteger = [](Value value) -> std::optional<int64_t> { ---------------- skatrak wrote:
Good idea, done. https://github.com/llvm/llvm-project/pull/116052 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits