https://github.com/RossBrunton updated https://github.com/llvm/llvm-project/pull/147972
>From 77a4183117cd259584c1bb4136aa27dd2b9548b0 Mon Sep 17 00:00:00 2001 From: Ross Brunton <r...@codeplay.com> Date: Thu, 10 Jul 2025 15:34:17 +0100 Subject: [PATCH] [Offload] Add global variable address/size queries Add two new symbol info types for getting the bounds of a global variable. As well as a number of tests for reading/writing to it. --- offload/liboffload/API/Symbol.td | 4 +- offload/liboffload/src/OffloadImpl.cpp | 19 ++++ offload/tools/offload-tblgen/PrintGen.cpp | 8 +- .../unittests/OffloadAPI/memory/olMemcpy.cpp | 105 ++++++++++++++++++ .../OffloadAPI/symbol/olGetSymbolInfo.cpp | 28 +++++ .../OffloadAPI/symbol/olGetSymbolInfoSize.cpp | 14 +++ 6 files changed, 175 insertions(+), 3 deletions(-) diff --git a/offload/liboffload/API/Symbol.td b/offload/liboffload/API/Symbol.td index 9317c71df1f10..2e94d703809e7 100644 --- a/offload/liboffload/API/Symbol.td +++ b/offload/liboffload/API/Symbol.td @@ -39,7 +39,9 @@ def : Enum { let desc = "Supported symbol info."; let is_typed = 1; let etors = [ - TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol."> + TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">, + TaggedEtor<"GLOBAL_VARIABLE_ADDRESS", "void *", "The address in memory for this global variable.">, + TaggedEtor<"GLOBAL_VARIABLE_SIZE", "size_t", "The size in bytes for this global variable.">, ]; } diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index 6d98c33ffb8da..17a2b00cb7140 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -753,9 +753,28 @@ Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol, void *PropValue, size_t *PropSizeRet) { InfoWriter Info(PropSize, PropValue, PropSizeRet); + auto CheckKind = [&](ol_symbol_kind_t Required) { + if (Symbol->Kind != Required) { + std::string ErrBuffer; + llvm::raw_string_ostream(ErrBuffer) + << PropName << ": Expected a symbol of Kind " << Required + << " but given a symbol of Kind " << Symbol->Kind; + return Plugin::error(ErrorCode::SYMBOL_KIND, ErrBuffer.c_str()); + } + return Plugin::success(); + }; + switch (PropName) { case OL_SYMBOL_INFO_KIND: return Info.write<ol_symbol_kind_t>(Symbol->Kind); + case OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS: + if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE)) + return Err; + return Info.write<void *>(std::get<GlobalTy>(Symbol->PluginImpl).getPtr()); + case OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE: + if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE)) + return Err; + return Info.write<size_t>(std::get<GlobalTy>(Symbol->PluginImpl).getSize()); default: return createOffloadError(ErrorCode::INVALID_ENUMERATION, "olGetSymbolInfo enum '%i' is invalid", PropName); diff --git a/offload/tools/offload-tblgen/PrintGen.cpp b/offload/tools/offload-tblgen/PrintGen.cpp index d1189688a90a3..89d7c820426cf 100644 --- a/offload/tools/offload-tblgen/PrintGen.cpp +++ b/offload/tools/offload-tblgen/PrintGen.cpp @@ -74,8 +74,12 @@ inline void printTagged(llvm::raw_ostream &os, const void *ptr, {0} value, size_ if (Type == "char[]") { OS << formatv(TAB_2 "printPtr(os, (const char*) ptr);\n"); } else { - OS << formatv(TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n", - Type); + if (Type == "void *") + OS << formatv(TAB_2 "void * const * const tptr = (void * " + "const * const)ptr;\n"); + else + OS << formatv( + TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n", Type); // TODO: Handle other cases here OS << TAB_2 "os << (const void *)tptr << \" (\";\n"; if (Type.ends_with("*")) { diff --git a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp index c1762b451b81d..c1fb6df9bad0d 100644 --- a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp +++ b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp @@ -13,6 +13,32 @@ using olMemcpyTest = OffloadQueueTest; OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyTest); +struct olMemcpyGlobalTest : OffloadGlobalTest { + void SetUp() override { + RETURN_ON_FATAL_FAILURE(OffloadGlobalTest::SetUp()); + ASSERT_SUCCESS( + olGetSymbol(Program, "read", OL_SYMBOL_KIND_KERNEL, &ReadKernel)); + ASSERT_SUCCESS( + olGetSymbol(Program, "write", OL_SYMBOL_KIND_KERNEL, &WriteKernel)); + ASSERT_SUCCESS(olCreateQueue(Device, &Queue)); + ASSERT_SUCCESS(olGetSymbolInfo( + Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, sizeof(Addr), &Addr)); + + LaunchArgs.Dimensions = 1; + LaunchArgs.GroupSize = {64, 1, 1}; + LaunchArgs.NumGroups = {1, 1, 1}; + + LaunchArgs.DynSharedMemory = 0; + } + + ol_kernel_launch_size_args_t LaunchArgs{}; + void *Addr; + ol_symbol_handle_t ReadKernel; + ol_symbol_handle_t WriteKernel; + ol_queue_handle_t Queue; +}; +OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyGlobalTest); + TEST_P(olMemcpyTest, SuccessHtoD) { constexpr size_t Size = 1024; void *Alloc; @@ -105,3 +131,82 @@ TEST_P(olMemcpyTest, SuccessSizeZero) { ASSERT_SUCCESS( olMemcpy(nullptr, Output.data(), Host, Input.data(), Host, 0, nullptr)); } + +TEST_P(olMemcpyGlobalTest, SuccessRoundTrip) { + void *SourceMem; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, + 64 * sizeof(uint32_t), &SourceMem)); + uint32_t *SourceData = (uint32_t *)SourceMem; + for (auto I = 0; I < 64; I++) + SourceData[I] = I; + + void *DestMem; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, + 64 * sizeof(uint32_t), &DestMem)); + + ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host, + 64 * sizeof(uint32_t), nullptr)); + ASSERT_SUCCESS(olWaitQueue(Queue)); + ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device, + 64 * sizeof(uint32_t), nullptr)); + ASSERT_SUCCESS(olWaitQueue(Queue)); + + uint32_t *DestData = (uint32_t *)DestMem; + for (uint32_t I = 0; I < 64; I++) + ASSERT_EQ(DestData[I], I); + + ASSERT_SUCCESS(olMemFree(DestMem)); + ASSERT_SUCCESS(olMemFree(SourceMem)); +} + +TEST_P(olMemcpyGlobalTest, SuccessWrite) { + void *SourceMem; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, + LaunchArgs.GroupSize.x * sizeof(uint32_t), + &SourceMem)); + uint32_t *SourceData = (uint32_t *)SourceMem; + for (auto I = 0; I < 64; I++) + SourceData[I] = I; + + void *DestMem; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, + LaunchArgs.GroupSize.x * sizeof(uint32_t), + &DestMem)); + struct { + void *Mem; + } Args{DestMem}; + + ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host, + 64 * sizeof(uint32_t), nullptr)); + ASSERT_SUCCESS(olWaitQueue(Queue)); + ASSERT_SUCCESS(olLaunchKernel(Queue, Device, ReadKernel, &Args, sizeof(Args), + &LaunchArgs, nullptr)); + ASSERT_SUCCESS(olWaitQueue(Queue)); + + uint32_t *DestData = (uint32_t *)DestMem; + for (uint32_t I = 0; I < 64; I++) + ASSERT_EQ(DestData[I], I); + + ASSERT_SUCCESS(olMemFree(DestMem)); + ASSERT_SUCCESS(olMemFree(SourceMem)); +} + +TEST_P(olMemcpyGlobalTest, SuccessRead) { + void *DestMem; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, + LaunchArgs.GroupSize.x * sizeof(uint32_t), + &DestMem)); + + ASSERT_SUCCESS(olLaunchKernel(Queue, Device, WriteKernel, nullptr, 0, + &LaunchArgs, nullptr)); + ASSERT_SUCCESS(olWaitQueue(Queue)); + ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device, + 64 * sizeof(uint32_t), nullptr)); + ASSERT_SUCCESS(olWaitQueue(Queue)); + + uint32_t *DestData = (uint32_t *)DestMem; + for (uint32_t I = 0; I < 64; I++) + ASSERT_EQ(DestData[I], I * 2); + + ASSERT_SUCCESS(olMemFree(DestMem)); +} diff --git a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp index 100a374430372..ed8f4716974cd 100644 --- a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp +++ b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp @@ -30,6 +30,34 @@ TEST_P(olGetSymbolInfoGlobalTest, SuccessKind) { ASSERT_EQ(RetrievedKind, OL_SYMBOL_KIND_GLOBAL_VARIABLE); } +TEST_P(olGetSymbolInfoKernelTest, InvalidAddress) { + void *RetrievedAddr; + ASSERT_ERROR(OL_ERRC_SYMBOL_KIND, + olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, + sizeof(RetrievedAddr), &RetrievedAddr)); +} + +TEST_P(olGetSymbolInfoGlobalTest, SuccessAddress) { + void *RetrievedAddr = nullptr; + ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, + sizeof(RetrievedAddr), &RetrievedAddr)); + ASSERT_NE(RetrievedAddr, nullptr); +} + +TEST_P(olGetSymbolInfoKernelTest, InvalidSize) { + size_t RetrievedSize; + ASSERT_ERROR(OL_ERRC_SYMBOL_KIND, + olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE, + sizeof(RetrievedSize), &RetrievedSize)); +} + +TEST_P(olGetSymbolInfoGlobalTest, SuccessSize) { + size_t RetrievedSize = 0; + ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE, + sizeof(RetrievedSize), &RetrievedSize)); + ASSERT_EQ(RetrievedSize, 64 * sizeof(uint32_t)); +} + TEST_P(olGetSymbolInfoKernelTest, InvalidNullHandle) { ol_symbol_kind_t RetrievedKind; ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, diff --git a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp index aa7a061a9ef7a..ec011865cc6ad 100644 --- a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp +++ b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp @@ -28,6 +28,20 @@ TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessKind) { ASSERT_EQ(Size, sizeof(ol_symbol_kind_t)); } +TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessAddress) { + size_t Size = 0; + ASSERT_SUCCESS(olGetSymbolInfoSize( + Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, &Size)); + ASSERT_EQ(Size, sizeof(void *)); +} + +TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessSize) { + size_t Size = 0; + ASSERT_SUCCESS( + olGetSymbolInfoSize(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE, &Size)); + ASSERT_EQ(Size, sizeof(size_t)); +} + TEST_P(olGetSymbolInfoSizeKernelTest, InvalidNullHandle) { size_t Size = 0; ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits