================ @@ -3529,6 +3549,84 @@ static void genMapInfos(llvm::IRBuilderBase &builder, } } +static llvm::Expected<llvm::Function *> +emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation); + +static llvm::Expected<llvm::Function *> +getOrCreateUserDefinedMapperFunc(Operation *declMapperOp, + llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + static llvm::DenseMap<const Operation *, llvm::Function *> userDefMapperMap; + auto iter = userDefMapperMap.find(declMapperOp); + if (iter != userDefMapperMap.end()) + return iter->second; + llvm::Expected<llvm::Function *> mapperFunc = + emitUserDefinedMapper(declMapperOp, builder, moduleTranslation); + if (!mapperFunc) + return mapperFunc.takeError(); + userDefMapperMap.try_emplace(declMapperOp, *mapperFunc); + return mapperFunc; +} + +static llvm::Expected<llvm::Function *> +emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto declMapperOp = cast<omp::DeclareMapperOp>(op); + auto declMapperInfoOp = + *declMapperOp.getOps<omp::DeclareMapperInfoOp>().begin(); + DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>()); + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType()); + std::string mapperName = ompBuilder->createPlatformSpecificName( + {"omp_mapper", declMapperOp.getSymName()}); + SmallVector<Value> mapVars = declMapperInfoOp.getMapVars(); + + using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; + + // Fill up the arrays with all the mapped variables. 
+ MapInfosTy combinedInfo; + auto genMapInfoCB = + [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI, + llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy { + builder.restoreIP(codeGenIP); + moduleTranslation.mapValue(declMapperOp.getRegion().getArgument(0), ptrPHI); + moduleTranslation.mapBlock(&declMapperOp.getRegion().front(), + builder.GetInsertBlock()); + if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(), + /*ignoreArguments=*/true, + builder))) + return llvm::make_error<PreviouslyReportedError>(); + MapInfoData mapData; + collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl, + builder); + genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData); + + // Drop the mapping that is no longer necessary so that the same region can + // be processed multiple times. + moduleTranslation.forgetMapping(declMapperOp.getRegion()); + return combinedInfo; + }; + + auto customMapperCB = [&](unsigned i, llvm::Function **mapperFunc) { + if (combinedInfo.Mappers[i]) { + // Call the corresponding mapper function. + llvm::Expected<llvm::Function *> newFn = getOrCreateUserDefinedMapperFunc( + combinedInfo.Mappers[i], builder, moduleTranslation); + assert(newFn && "Expect a valid mapper function is available"); ---------------- skatrak wrote:
This callback can fail, so it must return an `llvm::Expected<bool>`. This requires updating the callback type expected by the OMPIRBuilder to `function_ref<Expected<bool>(unsigned int, Function **)>` in every place where these kinds of callbacks are passed, so I'd suggest giving it a name in OMPIRBuilder.h to simplify things a bit (e.g. `using CustomMapperCallbackTy = function_ref<Expected<bool>(unsigned int, Function **)>`). That will also require checking the result and forwarding any errors produced by these callbacks wherever they are called inside of the OMPIRBuilder.

Currently, the `assert` here is compiled out in some builds. In that case, getting the result below without first having checked the error status triggers another assertion, regardless of whether there was actually an error. Instead, this should be done:

```c++
if (!newFn)
  return newFn.takeError();
```

https://github.com/llvm/llvm-project/pull/124746
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits