================ @@ -2081,6 +2083,79 @@ void collectMapDataFromMapOperands(MapInfoData &mapData, } } +static int getMapDataMemberIdx(MapInfoData &mapData, + mlir::omp::MapInfoOp memberOp) { + auto *res = llvm::find(mapData.MapClause, memberOp); + assert(res != mapData.MapClause.end()); + return std::distance(mapData.MapClause.begin(), res); +} + +static mlir::omp::MapInfoOp +getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) { + // Only 1 member has been mapped, we can return it. + if (mapInfo.getMembersIndex()->size() == 1) + if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>( + mapInfo.getMembers()[0].getDefiningOp())) + return mapOp; + + std::vector<size_t> indices( + mapInfo.getMembersIndexAttr().getShapedType().getShape()[0]); + std::iota(indices.begin(), indices.end(), 0); + + llvm::sort( + indices.begin(), indices.end(), [&](const size_t a, const size_t b) { + for (int i = 0; + i < mapInfo.getMembersIndexAttr().getShapedType().getShape()[1]; + ++i) { + int aIndex = + mapInfo.getMembersIndexAttr() + .getValues<int32_t>()[a * mapInfo.getMembersIndexAttr() + .getShapedType() + .getShape()[1] + + i]; + int bIndex = + mapInfo.getMembersIndexAttr() + .getValues<int32_t>()[b * mapInfo.getMembersIndexAttr() + .getShapedType() + .getShape()[1] + + i]; + + // As we have iterated to a stage where both indices are invalid + // we likely have the same member index, possibly the same member + // being mapped, return the first. + if (aIndex == -1 && bIndex == -1) + return true; + + if (aIndex == -1) + return true; + + if (bIndex == -1) + return false; + + // A is earlier in the record type layout than B + if (aIndex < bIndex) + return true; + + if (bIndex < aIndex) + return false; + } + + // iterated the entire list and couldn't make a decision, all elements + // were likely the same, return true for now similar to reaching the end + // of both and finding invalid indices. 
+ return true; + }); + + if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>( + mapInfo.getMembers()[((first) ? indices.front() : indices.back())] + .getDefiningOp())) + return mapOp; + + assert(false && "getFirstOrLastMappedMemberPtr could not find approproaite " + "map information"); + return {}; ---------------- skatrak wrote:
Nit: Replace the trailing `assert(false && ...)` / `return {}` with `llvm_unreachable()` if you want a custom failure message, or drop the `dyn_cast` check and simply `return llvm::cast<mlir::omp::MapInfoOp>(...)`, since `llvm::cast` already asserts that the cast is valid.

https://github.com/llvm/llvm-project/pull/82852
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits