https://github.com/RossBrunton updated 
https://github.com/llvm/llvm-project/pull/147972

>From 77a4183117cd259584c1bb4136aa27dd2b9548b0 Mon Sep 17 00:00:00 2001
From: Ross Brunton <r...@codeplay.com>
Date: Thu, 10 Jul 2025 15:34:17 +0100
Subject: [PATCH] [Offload] Add global variable address/size queries

Add two new symbol info types for querying the bounds (address and size)
of a global variable, along with a number of tests that read from and
write to it.
---
 offload/liboffload/API/Symbol.td              |   4 +-
 offload/liboffload/src/OffloadImpl.cpp        |  19 ++++
 offload/tools/offload-tblgen/PrintGen.cpp     |   8 +-
 .../unittests/OffloadAPI/memory/olMemcpy.cpp  | 105 ++++++++++++++++++
 .../OffloadAPI/symbol/olGetSymbolInfo.cpp     |  28 +++++
 .../OffloadAPI/symbol/olGetSymbolInfoSize.cpp |  14 +++
 6 files changed, 175 insertions(+), 3 deletions(-)

diff --git a/offload/liboffload/API/Symbol.td b/offload/liboffload/API/Symbol.td
index 9317c71df1f10..2e94d703809e7 100644
--- a/offload/liboffload/API/Symbol.td
+++ b/offload/liboffload/API/Symbol.td
@@ -39,7 +39,9 @@ def : Enum {
   let desc = "Supported symbol info.";
   let is_typed = 1;
   let etors = [
-    TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">
+    TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">,
+    TaggedEtor<"GLOBAL_VARIABLE_ADDRESS", "void *", "The address in memory for this global variable.">,
+    TaggedEtor<"GLOBAL_VARIABLE_SIZE", "size_t", "The size in bytes for this global variable.">,
   ];
 }
 
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 6d98c33ffb8da..17a2b00cb7140 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -753,9 +753,28 @@ Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol,
                                 void *PropValue, size_t *PropSizeRet) {
   InfoWriter Info(PropSize, PropValue, PropSizeRet);
 
+  auto CheckKind = [&](ol_symbol_kind_t Required) {
+    if (Symbol->Kind != Required) {
+      std::string ErrBuffer;
+      llvm::raw_string_ostream(ErrBuffer)
+          << PropName << ": Expected a symbol of Kind " << Required
+          << " but given a symbol of Kind " << Symbol->Kind;
+      return Plugin::error(ErrorCode::SYMBOL_KIND, ErrBuffer.c_str());
+    }
+    return Plugin::success();
+  };
+
   switch (PropName) {
   case OL_SYMBOL_INFO_KIND:
     return Info.write<ol_symbol_kind_t>(Symbol->Kind);
+  case OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS:
+    if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
+      return Err;
+    return Info.write<void *>(std::get<GlobalTy>(Symbol->PluginImpl).getPtr());
+  case OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE:
+    if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
+      return Err;
+    return Info.write<size_t>(std::get<GlobalTy>(Symbol->PluginImpl).getSize());
   default:
     return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                               "olGetSymbolInfo enum '%i' is invalid", PropName);
diff --git a/offload/tools/offload-tblgen/PrintGen.cpp b/offload/tools/offload-tblgen/PrintGen.cpp
index d1189688a90a3..89d7c820426cf 100644
--- a/offload/tools/offload-tblgen/PrintGen.cpp
+++ b/offload/tools/offload-tblgen/PrintGen.cpp
@@ -74,8 +74,12 @@ inline void printTagged(llvm::raw_ostream &os, const void *ptr, {0} value, size_
     if (Type == "char[]") {
       OS << formatv(TAB_2 "printPtr(os, (const char*) ptr);\n");
     } else {
-      OS << formatv(TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n",
-                    Type);
+      if (Type == "void *")
+        OS << formatv(TAB_2 "void * const * const tptr = (void * "
+                            "const * const)ptr;\n");
+      else
+        OS << formatv(
+            TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n", Type);
       // TODO: Handle other cases here
       OS << TAB_2 "os << (const void *)tptr << \" (\";\n";
       if (Type.ends_with("*")) {
diff --git a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
index c1762b451b81d..c1fb6df9bad0d 100644
--- a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
@@ -13,6 +13,32 @@
 using olMemcpyTest = OffloadQueueTest;
 OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyTest);
 
+struct olMemcpyGlobalTest : OffloadGlobalTest {
+  void SetUp() override {
+    RETURN_ON_FATAL_FAILURE(OffloadGlobalTest::SetUp());
+    ASSERT_SUCCESS(
+        olGetSymbol(Program, "read", OL_SYMBOL_KIND_KERNEL, &ReadKernel));
+    ASSERT_SUCCESS(
+        olGetSymbol(Program, "write", OL_SYMBOL_KIND_KERNEL, &WriteKernel));
+    ASSERT_SUCCESS(olCreateQueue(Device, &Queue));
+    ASSERT_SUCCESS(olGetSymbolInfo(
+        Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, sizeof(Addr), &Addr));
+
+    LaunchArgs.Dimensions = 1;
+    LaunchArgs.GroupSize = {64, 1, 1};
+    LaunchArgs.NumGroups = {1, 1, 1};
+
+    LaunchArgs.DynSharedMemory = 0;
+  }
+
+  ol_kernel_launch_size_args_t LaunchArgs{};
+  void *Addr;
+  ol_symbol_handle_t ReadKernel;
+  ol_symbol_handle_t WriteKernel;
+  ol_queue_handle_t Queue;
+};
+OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyGlobalTest);
+
 TEST_P(olMemcpyTest, SuccessHtoD) {
   constexpr size_t Size = 1024;
   void *Alloc;
@@ -105,3 +131,82 @@ TEST_P(olMemcpyTest, SuccessSizeZero) {
   ASSERT_SUCCESS(
       olMemcpy(nullptr, Output.data(), Host, Input.data(), Host, 0, nullptr));
 }
+
+TEST_P(olMemcpyGlobalTest, SuccessRoundTrip) {
+  void *SourceMem;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+                            64 * sizeof(uint32_t), &SourceMem));
+  uint32_t *SourceData = (uint32_t *)SourceMem;
+  for (auto I = 0; I < 64; I++)
+    SourceData[I] = I;
+
+  void *DestMem;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+                            64 * sizeof(uint32_t), &DestMem));
+
+  ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
+                          64 * sizeof(uint32_t), nullptr));
+  ASSERT_SUCCESS(olWaitQueue(Queue));
+  ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
+                          64 * sizeof(uint32_t), nullptr));
+  ASSERT_SUCCESS(olWaitQueue(Queue));
+
+  uint32_t *DestData = (uint32_t *)DestMem;
+  for (uint32_t I = 0; I < 64; I++)
+    ASSERT_EQ(DestData[I], I);
+
+  ASSERT_SUCCESS(olMemFree(DestMem));
+  ASSERT_SUCCESS(olMemFree(SourceMem));
+}
+
+TEST_P(olMemcpyGlobalTest, SuccessWrite) {
+  void *SourceMem;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+                            LaunchArgs.GroupSize.x * sizeof(uint32_t),
+                            &SourceMem));
+  uint32_t *SourceData = (uint32_t *)SourceMem;
+  for (auto I = 0; I < 64; I++)
+    SourceData[I] = I;
+
+  void *DestMem;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+                            LaunchArgs.GroupSize.x * sizeof(uint32_t),
+                            &DestMem));
+  struct {
+    void *Mem;
+  } Args{DestMem};
+
+  ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
+                          64 * sizeof(uint32_t), nullptr));
+  ASSERT_SUCCESS(olWaitQueue(Queue));
+  ASSERT_SUCCESS(olLaunchKernel(Queue, Device, ReadKernel, &Args, sizeof(Args),
+                                &LaunchArgs, nullptr));
+  ASSERT_SUCCESS(olWaitQueue(Queue));
+
+  uint32_t *DestData = (uint32_t *)DestMem;
+  for (uint32_t I = 0; I < 64; I++)
+    ASSERT_EQ(DestData[I], I);
+
+  ASSERT_SUCCESS(olMemFree(DestMem));
+  ASSERT_SUCCESS(olMemFree(SourceMem));
+}
+
+TEST_P(olMemcpyGlobalTest, SuccessRead) {
+  void *DestMem;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+                            LaunchArgs.GroupSize.x * sizeof(uint32_t),
+                            &DestMem));
+
+  ASSERT_SUCCESS(olLaunchKernel(Queue, Device, WriteKernel, nullptr, 0,
+                                &LaunchArgs, nullptr));
+  ASSERT_SUCCESS(olWaitQueue(Queue));
+  ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
+                          64 * sizeof(uint32_t), nullptr));
+  ASSERT_SUCCESS(olWaitQueue(Queue));
+
+  uint32_t *DestData = (uint32_t *)DestMem;
+  for (uint32_t I = 0; I < 64; I++)
+    ASSERT_EQ(DestData[I], I * 2);
+
+  ASSERT_SUCCESS(olMemFree(DestMem));
+}
diff --git a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp
index 100a374430372..ed8f4716974cd 100644
--- a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp
+++ b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp
@@ -30,6 +30,34 @@ TEST_P(olGetSymbolInfoGlobalTest, SuccessKind) {
   ASSERT_EQ(RetrievedKind, OL_SYMBOL_KIND_GLOBAL_VARIABLE);
 }
 
+TEST_P(olGetSymbolInfoKernelTest, InvalidAddress) {
+  void *RetrievedAddr;
+  ASSERT_ERROR(OL_ERRC_SYMBOL_KIND,
+               olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
+                               sizeof(RetrievedAddr), &RetrievedAddr));
+}
+
+TEST_P(olGetSymbolInfoGlobalTest, SuccessAddress) {
+  void *RetrievedAddr = nullptr;
+  ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
+                                 sizeof(RetrievedAddr), &RetrievedAddr));
+  ASSERT_NE(RetrievedAddr, nullptr);
+}
+
+TEST_P(olGetSymbolInfoKernelTest, InvalidSize) {
+  size_t RetrievedSize;
+  ASSERT_ERROR(OL_ERRC_SYMBOL_KIND,
+               olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
+                               sizeof(RetrievedSize), &RetrievedSize));
+}
+
+TEST_P(olGetSymbolInfoGlobalTest, SuccessSize) {
+  size_t RetrievedSize = 0;
+  ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
+                                 sizeof(RetrievedSize), &RetrievedSize));
+  ASSERT_EQ(RetrievedSize, 64 * sizeof(uint32_t));
+}
+
 TEST_P(olGetSymbolInfoKernelTest, InvalidNullHandle) {
   ol_symbol_kind_t RetrievedKind;
   ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
diff --git a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp
index aa7a061a9ef7a..ec011865cc6ad 100644
--- a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp
+++ b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp
@@ -28,6 +28,20 @@ TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessKind) {
   ASSERT_EQ(Size, sizeof(ol_symbol_kind_t));
 }
 
+TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessAddress) {
+  size_t Size = 0;
+  ASSERT_SUCCESS(olGetSymbolInfoSize(
+      Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, &Size));
+  ASSERT_EQ(Size, sizeof(void *));
+}
+
+TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessSize) {
+  size_t Size = 0;
+  ASSERT_SUCCESS(
+      olGetSymbolInfoSize(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE, &Size));
+  ASSERT_EQ(Size, sizeof(size_t));
+}
+
 TEST_P(olGetSymbolInfoSizeKernelTest, InvalidNullHandle) {
   size_t Size = 0;
   ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to