https://github.com/RossBrunton created https://github.com/llvm/llvm-project/pull/148648

A variant of `olCreateProgram` that accepts multiple bitcode images, links
them together, and loads the linked result as a single program.


>From 8589fcc6d053cb2937cf970d1ce354abfb84da31 Mon Sep 17 00:00:00 2001
From: Ross Brunton <r...@codeplay.com>
Date: Mon, 14 Jul 2025 16:05:41 +0100
Subject: [PATCH] [Offload] Add `olLinkProgram`

A variant of `olCreateProgram` that accepts multiple bitcode images, links
them together, and loads the linked result as a single program.
---
 offload/liboffload/API/Program.td             | 28 ++++++
 offload/liboffload/src/OffloadImpl.cpp        | 33 +++++++
 offload/plugins-nextgen/common/include/JIT.h  |  4 +
 .../common/include/PluginInterface.h          |  4 +
 offload/plugins-nextgen/common/src/JIT.cpp    | 41 ++++++++
 .../common/src/PluginInterface.cpp            |  7 ++
 offload/unittests/OffloadAPI/CMakeLists.txt   |  3 +-
 .../OffloadAPI/device_code/CMakeLists.txt     |  4 +
 .../unittests/OffloadAPI/device_code/link_a.c | 11 +++
 .../unittests/OffloadAPI/device_code/link_b.c | 10 ++
 .../OffloadAPI/program/olLinkProgram.cpp      | 99 +++++++++++++++++++
 11 files changed, 243 insertions(+), 1 deletion(-)
 create mode 100644 offload/unittests/OffloadAPI/device_code/link_a.c
 create mode 100644 offload/unittests/OffloadAPI/device_code/link_b.c
 create mode 100644 offload/unittests/OffloadAPI/program/olLinkProgram.cpp

diff --git a/offload/liboffload/API/Program.td b/offload/liboffload/API/Program.td
index 0476fa1f7c27a..3dae37f288ff7 100644
--- a/offload/liboffload/API/Program.td
+++ b/offload/liboffload/API/Program.td
@@ -25,6 +25,34 @@ def : Function {
     let returns = [];
 }
 
+// One input image for `olLinkProgram`: `Address`/`Size` describe a
+// caller-provided buffer that must contain a single bitcode module.
+// NOTE(review): `Address` is only read by the implementation; a
+// `const void *` would spare callers the `const_cast` seen in the tests.
+def : Struct {
+    let name = "ol_program_link_buffer_t";
+    let desc = "An image to link with `olLinkProgram`.";
+    let members = [
+        StructMember<"void *", "Address", "base address of memory image">,
+        StructMember<"size_t", "Size", "size in bytes of memory image">,
+    ];
+}
+
+// Entry point that compiles and links several bitcode images and loads the
+// result as one program. (String literals must stay on one line: TableGen
+// strings cannot span a line break.)
+def : Function {
+    let name = "olLinkProgram";
+    let desc = "Compile and link multiple bitcode images into a single binary.";
+    let details = [
+        "No caching is performed; multiple calls to `olLinkProgram` with the same images will result in multiple linking operations",
+    ];
+    let params = [
+        Param<"ol_device_handle_t", "Device", "handle of the device to link for", PARAM_IN>,
+        Param<"ol_program_link_buffer_t *", "Images", "a pointer to an array of `ImagesSize` entries, one for each image to link", PARAM_IN>,
+        Param<"size_t", "ImagesSize", "the number of elements in `Images`", PARAM_IN>,
+        Param<"ol_program_handle_t*", "Program", "output handle for the created program", PARAM_OUT>
+    ];
+    let returns = [
+        Return<"OL_ERRC_INVALID_SIZE", ["`ImagesSize == 0`"]>,
+        Return<"OL_ERRC_INVALID_BINARY", ["Any image is not in the bitcode format"]>,
+        Return<"OL_ERRC_UNSUPPORTED", ["Linking is not supported for this device and `ImagesSize` > 1"]>,
+    ];
+}
+
 def : Function {
     let name = "olDestroyProgram";
     let desc = "Destroy the program and free all underlying resources.";
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 17a2b00cb7140..14af015460c8c 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -650,6 +650,39 @@ Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
   return Error::success();
 }
 
+/// Implementation of olLinkProgram: JIT-compiles and links the bitcode images
+/// in \p Images for \p Device, loads the linked image onto the device and
+/// returns the new program through \p Program.
+Error olLinkProgram_impl(ol_device_impl_t *Device,
+                         ol_program_link_buffer_t *Images, size_t ImagesSize,
+                         ol_program_handle_t *Program) {
+  // The API documents OL_ERRC_INVALID_SIZE for an empty image list; enforce it
+  // here rather than relying solely on entry-point validation, since the JIT
+  // has nothing meaningful to link.
+  if (ImagesSize == 0)
+    return error::createOffloadError(error::ErrorCode::INVALID_SIZE,
+                                     "no images provided to link");
+
+  std::vector<__tgt_device_image> DevImages;
+  DevImages.reserve(ImagesSize);
+  for (size_t I = 0; I < ImagesSize; I++) {
+    const auto &ProgData = Images[I];
+    DevImages.push_back({ProgData.Address,
+                         utils::advancePtr(ProgData.Address, ProgData.Size),
+                         nullptr, nullptr});
+  }
+
+  auto LinkResult =
+      Device->Device->jitLinkBinary(Device->Device->Plugin, DevImages);
+  if (!LinkResult)
+    return LinkResult.takeError();
+
+  // The linked buffer is kept alive by the JIT engine (JITEngine::link stores
+  // it), which is why no image buffer is handed to the program here.
+  auto Prog =
+      std::make_unique<ol_program_impl_t>(nullptr, nullptr, *LinkResult);
+
+  auto Res =
+      Device->Device->loadBinary(Device->Device->Plugin, &Prog->DeviceImage);
+  if (!Res)
+    return Res.takeError(); // Prog is released automatically.
+  assert(*Res != nullptr && "loadBinary returned nullptr");
+
+  Prog->Image = *Res;
+  *Program = Prog.release();
+
+  return Error::success();
+}
+
 Error olDestroyProgram_impl(ol_program_handle_t Program) {
   auto &Device = Program->Image->getDevice();
   if (auto Err = Device.unloadBinary(Program->Image))
diff --git a/offload/plugins-nextgen/common/include/JIT.h b/offload/plugins-nextgen/common/include/JIT.h
index 1d6280a0af141..08b82c4aefb8d 100644
--- a/offload/plugins-nextgen/common/include/JIT.h
+++ b/offload/plugins-nextgen/common/include/JIT.h
@@ -55,6 +55,10 @@ struct JITEngine {
   process(const __tgt_device_image &Image,
           target::plugin::GenericDeviceTy &Device);
 
+  /// Link and compile multiple bitcode images into a single binary.
+  ///
+  /// Every entry of \p Images must be LLVM bitcode; otherwise an
+  /// INVALID_BINARY error is returned. The returned image points into a
+  /// buffer stored in the engine's per-compute-unit state, not one owned by
+  /// the caller.
+  Expected<__tgt_device_image> link(std::vector<__tgt_device_image> &Images,
+                                    target::plugin::GenericDeviceTy &Device);
+
 private:
   /// Compile the bitcode image \p Image and generate the binary image that can
   /// be loaded to the target device of the triple \p Triple architecture \p
diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h
index 7824257d28e1f..79e021cc64f3b 100644
--- a/offload/plugins-nextgen/common/include/PluginInterface.h
+++ b/offload/plugins-nextgen/common/include/PluginInterface.h
@@ -749,6 +749,10 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
   /// Load the binary image into the device and return the target table.
   Expected<DeviceImageTy *> loadBinary(GenericPluginTy &Plugin,
                                        const __tgt_device_image *TgtImage);
+  /// Link and compile multiple bitcode images into a single image using
+  /// \p Plugin's JIT engine. The result is not loaded; pass it to loadBinary
+  /// to load it onto this device.
+  Expected<__tgt_device_image>
+  jitLinkBinary(GenericPluginTy &Plugin,
+                std::vector<__tgt_device_image> InputImages);
   virtual Expected<DeviceImageTy *>
   loadBinaryImpl(const __tgt_device_image *TgtImage, int32_t ImageId) = 0;
 
diff --git a/offload/plugins-nextgen/common/src/JIT.cpp b/offload/plugins-nextgen/common/src/JIT.cpp
index 835dcc0da2ec9..2cf6ddbfdff0b 100644
--- a/offload/plugins-nextgen/common/src/JIT.cpp
+++ b/offload/plugins-nextgen/common/src/JIT.cpp
@@ -327,3 +327,44 @@ JITEngine::process(const __tgt_device_image &Image,
 
   return &Image;
 }
+
+/// Compile every bitcode image in \p Images for \p Device's compute unit kind,
+/// then hand the resulting object files to the device's JIT post-processing
+/// hook to produce one loadable image. The output buffer is appended to the
+/// compute unit's JIT image list, so the returned descriptor points into
+/// engine-owned memory.
+Expected<__tgt_device_image>
+JITEngine::link(std::vector<__tgt_device_image> &Images,
+                target::plugin::GenericDeviceTy &Device) {
+  const std::string &ComputeUnitKind = Device.getComputeUnitKind();
+
+  // Lock before touching ComputeUnitMap: operator[] may insert a new entry,
+  // which must not race with other threads using the map.
+  std::lock_guard<std::mutex> Lock(ComputeUnitMapMutex);
+  ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind];
+
+  llvm::SmallVector<std::unique_ptr<MemoryBuffer>> Buffers;
+  Buffers.reserve(Images.size());
+  for (size_t Index = 0; Index < Images.size(); ++Index) {
+    auto &Image = Images[Index];
+    if (!isImageBitcode(Image))
+      // Index is size_t, so it needs %zu in this printf-style format.
+      return error::createOffloadError(
+          error::ErrorCode::INVALID_BINARY,
+          "binary %zu provided to link operation is not bitcode", Index);
+
+    auto ObjMBOrErr = getOrCreateObjFile(Image, CUI.Context, ComputeUnitKind);
+    if (!ObjMBOrErr)
+      return ObjMBOrErr.takeError();
+    Buffers.push_back(std::move(*ObjMBOrErr));
+  }
+
+  // The device-specific post-processing hook is expected to combine the object
+  // files into a single image.
+  auto ImageMBOrErr = Device.doJITPostProcessing(std::move(Buffers));
+  if (!ImageMBOrErr)
+    return ImageMBOrErr.takeError();
+
+  auto &ImageMB = CUI.JITImages.emplace_back(std::move(*ImageMBOrErr));
+  __tgt_device_image JITedImage{};
+  JITedImage.ImageStart = const_cast<char *>(ImageMB->getBufferStart());
+  JITedImage.ImageEnd = const_cast<char *>(ImageMB->getBufferEnd());
+
+  return JITedImage;
+}
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index 81b9d423e13d8..9e2234dcc148b 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -903,6 +903,13 @@ Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) {
 
   return deinitImpl();
 }
+
+Expected<__tgt_device_image>
+GenericDeviceTy::jitLinkBinary(GenericPluginTy &Plugin,
+                               std::vector<__tgt_device_image> InputImages) {
+  // Compilation and linking are delegated entirely to the plugin's JIT
+  // engine, targeting this device.
+  auto &Engine = Plugin.getJIT();
+  return Engine.link(InputImages, *this);
+}
+
 Expected<DeviceImageTy *>
 GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
                             const __tgt_device_image *InputTgtImage) {
diff --git a/offload/unittests/OffloadAPI/CMakeLists.txt b/offload/unittests/OffloadAPI/CMakeLists.txt
index d76338612210d..e13ded4f8b1aa 100644
--- a/offload/unittests/OffloadAPI/CMakeLists.txt
+++ b/offload/unittests/OffloadAPI/CMakeLists.txt
@@ -32,7 +32,8 @@ add_offload_unittest("platform"
 
 add_offload_unittest("program"
     program/olCreateProgram.cpp
-    program/olDestroyProgram.cpp)
+    program/olDestroyProgram.cpp
+    program/olLinkProgram.cpp)
 
 add_offload_unittest("queue"
     queue/olCreateQueue.cpp
diff --git a/offload/unittests/OffloadAPI/device_code/CMakeLists.txt b/offload/unittests/OffloadAPI/device_code/CMakeLists.txt
index 11c8ccbd6c7c5..c3e07724086fe 100644
--- a/offload/unittests/OffloadAPI/device_code/CMakeLists.txt
+++ b/offload/unittests/OffloadAPI/device_code/CMakeLists.txt
@@ -8,6 +8,8 @@ add_offload_test_device_code(localmem_static.c localmem_static)
 add_offload_test_device_code(global.c global)
 add_offload_test_device_code(global_ctor.c global_ctor)
 add_offload_test_device_code(global_dtor.c global_dtor)
+add_offload_test_device_code(link_a.c link_a)
+add_offload_test_device_code(link_b.c link_b)
 
 add_custom_target(offload_device_binaries DEPENDS
     foo.bin
@@ -19,5 +21,7 @@ add_custom_target(offload_device_binaries DEPENDS
     global.bin
     global_ctor.bin
     global_dtor.bin
+    link_a.bin
+    link_b.bin
 )
 set(OFFLOAD_TEST_DEVICE_CODE_PATH ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE)
diff --git a/offload/unittests/OffloadAPI/device_code/link_a.c b/offload/unittests/OffloadAPI/device_code/link_a.c
new file mode 100644
index 0000000000000..7feb92189c018
--- /dev/null
+++ b/offload/unittests/OffloadAPI/device_code/link_a.c
@@ -0,0 +1,11 @@
+#include <gpuintrin.h>
+#include <stdint.h>
+
+// Defined here and written by funky() in link_b.c.
+uint32_t global;
+
+// Defined in link_b.c; only resolvable once both images are linked together.
+// Declared with an explicit (void) prototype rather than an unprototyped ().
+extern uint32_t funky(void);
+
+// Stores funky()'s return value, then the value funky() left in `global`.
+__gpu_kernel void link_a(uint32_t *out) {
+  out[0] = funky();
+  out[1] = global;
+}
diff --git a/offload/unittests/OffloadAPI/device_code/link_b.c b/offload/unittests/OffloadAPI/device_code/link_b.c
new file mode 100644
index 0000000000000..82f41fd8a0218
--- /dev/null
+++ b/offload/unittests/OffloadAPI/device_code/link_b.c
@@ -0,0 +1,10 @@
+#include <gpuintrin.h>
+#include <stdint.h>
+
+// Defined in link_a.c as a scalar. The declaration must use the same type as
+// that definition: declaring it as an array is an incompatible declaration of
+// the same object, which is undefined behaviour once the images are linked.
+extern uint32_t global;
+
+// Called by the link_a kernel. Kept at default visibility, presumably so the
+// symbol stays resolvable from link_a after linking.
+[[gnu::visibility("default")]]
+uint32_t funky(void) {
+  global = 100;
+  return 200;
+}
diff --git a/offload/unittests/OffloadAPI/program/olLinkProgram.cpp b/offload/unittests/OffloadAPI/program/olLinkProgram.cpp
new file mode 100644
index 0000000000000..122d0156e6e0c
--- /dev/null
+++ b/offload/unittests/OffloadAPI/program/olLinkProgram.cpp
@@ -0,0 +1,99 @@
+//===------- Offload API tests - olLinkProgram ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../common/Fixtures.hpp"
+#include <OffloadAPI.h>
+#include <gtest/gtest.h>
+
+// Built on the queue fixture because SuccessBuild launches the linked kernel.
+using olLinkProgramTest = OffloadQueueTest;
+OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olLinkProgramTest);
+
+TEST_P(olLinkProgramTest, SuccessSingle) {
+  // A single image needs no cross-module linking; per the API,
+  // OL_ERRC_UNSUPPORTED only applies when more than one image is given.
+  std::unique_ptr<llvm::MemoryBuffer> DeviceBin;
+  ASSERT_TRUE(TestEnvironment::loadDeviceBinary("foo", Device, DeviceBin));
+  // A size is unsigned, so ">= 0" would always hold; require a non-empty image.
+  ASSERT_GT(DeviceBin->getBufferSize(), 0u);
+
+  ol_program_link_buffer_t Buffers[1] = {
+      {const_cast<char *>(DeviceBin->getBufferStart()),
+       DeviceBin->getBufferSize()},
+  };
+
+  ol_program_handle_t Program = nullptr;
+  ASSERT_SUCCESS(olLinkProgram(Device, Buffers, 1, &Program));
+  ASSERT_NE(Program, nullptr);
+
+  ASSERT_SUCCESS(olDestroyProgram(Program));
+}
+
+TEST_P(olLinkProgramTest, SuccessBuild) {
+  // link_a's kernel calls funky(), defined in link_b, which writes the global
+  // defined in link_a: the kernel only works if both images are linked.
+  std::unique_ptr<llvm::MemoryBuffer> ABin;
+  ASSERT_TRUE(TestEnvironment::loadDeviceBinary("link_a", Device, ABin));
+  std::unique_ptr<llvm::MemoryBuffer> BBin;
+  ASSERT_TRUE(TestEnvironment::loadDeviceBinary("link_b", Device, BBin));
+
+  ol_program_link_buffer_t Buffers[2] = {
+      {const_cast<char *>(ABin->getBufferStart()), ABin->getBufferSize()},
+      {const_cast<char *>(BBin->getBufferStart()), BBin->getBufferSize()},
+  };
+
+  ol_program_handle_t Program = nullptr;
+  auto LinkResult = olLinkProgram(Device, Buffers, 2, &Program);
+  if (LinkResult && LinkResult->Code == OL_ERRC_UNSUPPORTED)
+    GTEST_SKIP() << "Linking unsupported: " << LinkResult->Details;
+  ASSERT_SUCCESS(LinkResult);
+  ASSERT_NE(Program, nullptr);
+
+  ol_symbol_handle_t Kernel;
+  ASSERT_SUCCESS(
+      olGetSymbol(Program, "link_a", OL_SYMBOL_KIND_KERNEL, &Kernel));
+
+  void *Mem;
+  ASSERT_SUCCESS(
+      olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 2 * sizeof(uint32_t), &Mem));
+  struct {
+    void *Mem;
+  } Args{Mem};
+  ol_kernel_launch_size_args_t LaunchArgs{};
+  LaunchArgs.GroupSize = {1, 1, 1};
+  LaunchArgs.NumGroups = {1, 1, 1};
+  LaunchArgs.Dimensions = 1;
+
+  ASSERT_SUCCESS(olLaunchKernel(Queue, Device, Kernel, &Args, sizeof(Args),
+                                &LaunchArgs, nullptr));
+  ASSERT_SUCCESS(olWaitQueue(Queue));
+
+  // funky() returns 200 and stores 100 into link_a's global. Unsigned
+  // literals keep the comparison free of sign-compare warnings.
+  uint32_t *Data = static_cast<uint32_t *>(Mem);
+  ASSERT_EQ(Data[0], 200u);
+  ASSERT_EQ(Data[1], 100u);
+
+  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olDestroyProgram(Program));
+}
+
+TEST_P(olLinkProgramTest, InvalidNotBitcode) {
+  // Begins with the ELF magic number followed by filler bytes, so it is
+  // recognisably not LLVM bitcode and must be rejected by the linker path.
+  char FakeElf[] =
+      "\177ELF0000000000000000000000000000000000000000000000000000"
+      "00000000000000000000000000000000000000000000000000000000000";
+  ol_program_link_buffer_t Buffers[] = {{FakeElf, sizeof(FakeElf)}};
+
+  ol_program_handle_t Program;
+  auto Result = olLinkProgram(Device, Buffers, 1, &Program);
+  ASSERT_ERROR(OL_ERRC_INVALID_BINARY, Result);
+}
+
+TEST_P(olLinkProgramTest, InvalidSize) {
+  // Zero-length arrays are a compiler extension, not standard C++: use a
+  // valid one-element array and pass an image count of zero instead.
+  ol_program_link_buffer_t Buffers[1] = {};
+
+  ol_program_handle_t Program = nullptr;
+  ASSERT_ERROR(OL_ERRC_INVALID_SIZE,
+               olLinkProgram(Device, Buffers, 0, &Program));
+}

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to