================
@@ -0,0 +1,223 @@
+//===- SyclRuntimeWrappers.cpp - MLIR SYCL wrapper library ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements C wrappers around the sycl runtime library.
+//
+//===----------------------------------------------------------------------===//
+
+#include <algorithm>
+#include <array>
+#include <atomic>
+#include <cassert>
+#include <cfloat>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <stdexcept>
+#include <tuple>
+#include <vector>
+
+#include <CL/sycl.hpp>
+#include <level_zero/ze_api.h>
+#include <map>
+#include <mutex>
+#include <sycl/ext/oneapi/backend/level_zero.hpp>
+
+#ifdef _WIN32
+#define SYCL_RUNTIME_EXPORT __declspec(dllexport)
+#else
+#define SYCL_RUNTIME_EXPORT
+#endif // _WIN32
+
+namespace {
+
+/// Invokes \p func and returns whatever it returns. If \p func throws, a
+/// message describing the exception is written to stdout, stdout is flushed,
+/// and the process is terminated with abort().
+template <typename F>
+auto catchAll(F &&func) {
+  try {
+    return func();
+  } catch (const std::exception &ex) {
+    std::fprintf(stdout, "An exception was thrown: %s\n", ex.what());
+  } catch (...) {
+    std::fputs("An unknown exception was thrown\n", stdout);
+  }
+  // Only reached from one of the handlers above: flush the report and die.
+  std::fflush(stdout);
+  std::abort();
+}
+
+#define L0_SAFE_CALL(call)                                                     
\
+  {                                                                            
\
+    ze_result_t status = (call);                                               
\
+    if (status != ZE_RESULT_SUCCESS) {                                         
\
+      fprintf(stdout, "L0 error %d\n", status);                                
\
+      fflush(stdout);                                                          
\
+      abort();                                                                 
\
+    }                                                                          
\
+  }
+
+} // namespace
+
+/// Returns the first device of the first platform whose name contains
+/// "Level-Zero".
+///
+/// Platforms that report no devices are skipped instead of being indexed, so
+/// an empty device list never leads to an out-of-bounds access.
+///
+/// \throws std::runtime_error if no Level-Zero platform with at least one
+/// device is found.
+static sycl::device getDefaultDevice() {
+  for (const auto &platform : sycl::platform::get_platforms()) {
+    const auto platformName = platform.get_info<sycl::info::platform::name>();
+    if (platformName.find("Level-Zero") == std::string::npos)
+      continue;
+
+    const auto devices = platform.get_devices();
+    if (!devices.empty())
+      return devices.front();
+  }
+  throw std::runtime_error("getDefaultDevice failed");
+}
+
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wglobal-constructors"
+
+// Global SYCL device, obtained from getDefaultDevice(), and a context built
+// on that device. The device is defined first because the context's
+// initializer reads it.
+sycl::device syclDevice(getDefaultDevice());
+sycl::context syclContext(syclDevice);
+
+#pragma clang diagnostic pop
+
+/// Holds a sycl::queue created on the global syclContext and syclDevice.
+struct QUEUE {
+  sycl::queue syclQueue_;
+
+  // Build the queue in the member initializer list. Assigning inside the
+  // constructor body would first default-construct a sycl::queue (which picks
+  // a default device on its own) only to overwrite it immediately.
+  QUEUE() : syclQueue_(syclContext, syclDevice) {}
+};
+
+/// Allocates \p size bytes with 64-byte alignment on the global syclDevice /
+/// syclContext, via sycl::aligned_alloc_shared when \p isShared is true and
+/// sycl::aligned_alloc_device otherwise. \p queue is not consulted.
+///
+/// \throws std::runtime_error if the allocation call returns a null pointer.
+static void *allocDeviceMemory(QUEUE *queue, size_t size, bool isShared) {
+  constexpr size_t alignment = 64;
+  void *const allocation =
+      isShared
+          ? sycl::aligned_alloc_shared(alignment, size, syclDevice, syclContext)
+          : sycl::aligned_alloc_device(alignment, size, syclDevice,
+                                       syclContext);
+  if (!allocation)
+    throw std::runtime_error("mem allocation failed!");
+  return allocation;
+}
+
+/// Releases \p ptr through sycl::free, passing the SYCL queue stored in
+/// \p queue.
+static void deallocDeviceMemory(QUEUE *queue, void *ptr) {
+  const sycl::queue &owningQueue = queue->syclQueue_;
+  sycl::free(ptr, owningQueue);
+}
+
+/// Creates a Level Zero module from the \p dataSize bytes of SPIR-V located at
+/// \p data, using the native handles of the global syclDevice and syclContext.
+/// A failing zeModuleCreate is handled by L0_SAFE_CALL. \p data must be
+/// non-null (checked with assert).
+static ze_module_handle_t loadModule(const void *data, size_t dataSize) {
+  assert(data);
+
+  // Fill the descriptor field by field; no build flags or specialization
+  // constants are supplied.
+  ze_module_desc_t moduleDesc = {};
+  moduleDesc.stype = ZE_STRUCTURE_TYPE_MODULE_DESC;
+  moduleDesc.pNext = nullptr;
+  moduleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
+  moduleDesc.inputSize = dataSize;
+  moduleDesc.pInputModule = static_cast<const uint8_t *>(data);
+  moduleDesc.pBuildFlags = nullptr;
+  moduleDesc.pConstants = nullptr;
+
+  auto nativeDevice =
+      sycl::get_native<sycl::backend::ext_oneapi_level_zero>(syclDevice);
+  auto nativeContext =
+      sycl::get_native<sycl::backend::ext_oneapi_level_zero>(syclContext);
+
+  ze_module_handle_t moduleHandle = nullptr;
+  L0_SAFE_CALL(zeModuleCreate(nativeContext, nativeDevice, &moduleDesc,
+                              &moduleHandle, nullptr));
+  return moduleHandle;
+}
+
+/// Creates the Level Zero kernel named \p name from \p zeModule and wraps it
+/// in a sycl::kernel associated with the global syclContext.
+///
+/// Both arguments must be non-null (checked with assert). A failing
+/// zeKernelCreate is handled by L0_SAFE_CALL.
+///
+/// \returns a sycl::kernel allocated with `new`; this function does not
+/// release it, so the caller is responsible for deleting it.
+static sycl::kernel *getKernel(ze_module_handle_t zeModule, const char *name) {
+  assert(zeModule);
+  assert(name);
+
+  // Level Zero descriptors carry a structure-type tag. Value-initialization
+  // leaves stype at 0, which is not a valid ze_structure_type_t, so set it
+  // explicitly.
+  ze_kernel_desc_t desc = {};
+  desc.stype = ZE_STRUCTURE_TYPE_KERNEL_DESC;
+  desc.pKernelName = name;
+
+  ze_kernel_handle_t zeKernel = nullptr;
+  L0_SAFE_CALL(zeKernelCreate(zeModule, &desc, &zeKernel));
+
+  // Wrap the native module in an executable kernel bundle, then wrap the
+  // native kernel as a member of that bundle.
+  sycl::kernel_bundle<sycl::bundle_state::executable> kernelBundle =
+      sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
+                               sycl::bundle_state::executable>({zeModule},
+                                                               syclContext);
+
+  auto kernel = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
+      {kernelBundle, zeKernel}, syclContext);
+  return new sycl::kernel(kernel);
+}
+
+static void launchKernel(QUEUE *queue, sycl::kernel *kernel, size_t gridX,
+                         size_t gridY, size_t gridZ, size_t blockX,
+                         size_t blockY, size_t blockZ, size_t sharedMemBytes,
+                         void **params, size_t paramsCount) {
+  auto syclGlobalRange =
+      ::sycl::range<3>(blockZ * gridZ, blockY * gridY, blockX * gridX);
+  auto syclLocalRange = ::sycl::range<3>(blockZ, blockY, blockX);
+  sycl::nd_range<3> syclNdRange(
+      sycl::nd_range<3>(syclGlobalRange, syclLocalRange));
----------------
keryell wrote:

```suggestion
  sycl::nd_range<3> syclNdRange(syclGlobalRange, syclLocalRange);
```

https://github.com/llvm/llvm-project/pull/65539
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to