================ @@ -3421,6 +3441,85 @@ static void genMapInfos(llvm::IRBuilderBase &builder, } } +static llvm::Expected<llvm::Function *> +emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation); + +static llvm::Expected<llvm::Function *> +getOrCreateUserDefinedMapperFunc(Operation *declMapperOp, + llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + llvm::DenseMap<const Operation *, llvm::Function *> userDefMapperMap; + auto iter = userDefMapperMap.find(declMapperOp); + if (iter != userDefMapperMap.end()) + return iter->second; + llvm::Expected<llvm::Function *> mapperFunc = + emitUserDefinedMapper(declMapperOp, builder, moduleTranslation); + if (!mapperFunc) + return mapperFunc.takeError(); + userDefMapperMap.try_emplace(declMapperOp, *mapperFunc); + return userDefMapperMap.lookup(declMapperOp); +} + +static llvm::Expected<llvm::Function *> +emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto declMapperOp = cast<omp::DeclareMapperOp>(op); + auto declMapperInfoOp = + *declMapperOp.getOps<omp::DeclareMapperInfoOp>().begin(); + DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>()); + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + llvm::Type *varType = + moduleTranslation.convertType(declMapperOp.getVarType()); + std::string mapperName = ompBuilder->createPlatformSpecificName( + {"omp_mapper", declMapperOp.getSymName()}); + SmallVector<Value> mapVars = declMapperInfoOp.getMapVars(); + + using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; + + // Fill up the arrays with all the mapped variables. 
+ MapInfosTy combinedInfo; + auto genMapInfoCB = + [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI, + llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy { + builder.restoreIP(codeGenIP); + moduleTranslation.mapValue(declMapperOp.getRegion().getArgument(0), ptrPHI); + moduleTranslation.mapBlock(&declMapperOp.getRegion().front(), + builder.GetInsertBlock()); + if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(), + /*ignoreArguments=*/true, + builder))) + return llvm::make_error<PreviouslyReportedError>(); + MapInfoData mapData; + collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl, + builder); + genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData); + + // Drop the mapping that is no longer necessary so that the same region can + // be processed multiple times. + moduleTranslation.forgetMapping(declMapperOp.getRegion()); + return combinedInfo; + }; + + auto customMapperCB = [&](unsigned i, llvm::Function **mapperFunc) { + if (combinedInfo.Mappers[i]) { + // Call the corresponding mapper function. + llvm::Expected<llvm::Function *> newFn = getOrCreateUserDefinedMapperFunc( ---------------- ergawy wrote:
> A declare mapper may refer to another mapper in its mapping scheme... Can you please add a test that exercises this recursion, so we can better understand how it works? I believe the two tests added in the PR don't achieve that. I am just a bit uncomfortable doing this without being sure how it works. In general, instead of the recursion, we could try to come up with a worklist algorithm that builds up the list of mapper functions that need to be generated. https://github.com/llvm/llvm-project/pull/124746 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits