yaxunl created this revision.
yaxunl added a reviewer: tra.
Herald added subscribers: dexonsmith, jdoerfert, hiraditya, mgorny.
Herald added a reviewer: aaron.ballman.
yaxunl requested review of this revision.
Herald added a project: LLVM.
Herald added a subscriber: llvm-commits.

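This patch implements codegen for the `__managed__` variable attribute in HIP.
It also moves `createReplacementInstr` out of the XCore backend into a shared
LLVM IR utility (`llvm/IR/ReplaceConstant.h`) so that clang can reuse it.

Diagnostics will be added later.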


https://reviews.llvm.org/D94814

Files:
  clang/include/clang/Basic/Attr.td
  clang/include/clang/Basic/AttrDocs.td
  clang/include/clang/Sema/SemaInternal.h
  clang/lib/CodeGen/CGCUDANV.cpp
  clang/lib/CodeGen/CGCUDARuntime.h
  clang/lib/CodeGen/CGDeclCXX.cpp
  clang/lib/CodeGen/CodeGenModule.cpp
  clang/lib/CodeGen/TargetInfo.cpp
  clang/lib/Sema/SemaCUDA.cpp
  clang/lib/Sema/SemaDeclAttr.cpp
  clang/lib/Sema/SemaExpr.cpp
  clang/test/CodeGenCUDA/Inputs/cuda.h
  clang/test/CodeGenCUDA/managed-var.cu
  clang/test/Misc/pragma-attribute-supported-attributes-list.test
  llvm/include/llvm/IR/ReplaceConstant.h
  llvm/lib/IR/CMakeLists.txt
  llvm/lib/IR/ReplaceConstant.cpp
  llvm/lib/Target/XCore/XCoreLowerThreadLocal.cpp

Index: llvm/lib/Target/XCore/XCoreLowerThreadLocal.cpp
===================================================================
--- llvm/lib/Target/XCore/XCoreLowerThreadLocal.cpp
+++ llvm/lib/Target/XCore/XCoreLowerThreadLocal.cpp
@@ -21,6 +21,7 @@
 #include "llvm/IR/IntrinsicsXCore.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/NoFolder.h"
+#include "llvm/IR/ReplaceConstant.h"
 #include "llvm/IR/ValueHandle.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
@@ -74,57 +75,6 @@
   return ConstantArray::get(NewType, Elements);
 }
 
-static Instruction *
-createReplacementInstr(ConstantExpr *CE, Instruction *Instr) {
-  IRBuilder<NoFolder> Builder(Instr);
-  unsigned OpCode = CE->getOpcode();
-  switch (OpCode) {
-    case Instruction::GetElementPtr: {
-      SmallVector<Value *, 4> CEOpVec(CE->operands());
-      ArrayRef<Value *> CEOps(CEOpVec);
-      return dyn_cast<Instruction>(Builder.CreateInBoundsGEP(
-          cast<GEPOperator>(CE)->getSourceElementType(), CEOps[0],
-          CEOps.slice(1)));
-    }
-    case Instruction::Add:
-    case Instruction::Sub:
-    case Instruction::Mul:
-    case Instruction::UDiv:
-    case Instruction::SDiv:
-    case Instruction::FDiv:
-    case Instruction::URem:
-    case Instruction::SRem:
-    case Instruction::FRem:
-    case Instruction::Shl:
-    case Instruction::LShr:
-    case Instruction::AShr:
-    case Instruction::And:
-    case Instruction::Or:
-    case Instruction::Xor:
-      return dyn_cast<Instruction>(
-                  Builder.CreateBinOp((Instruction::BinaryOps)OpCode,
-                                      CE->getOperand(0), CE->getOperand(1),
-                                      CE->getName()));
-    case Instruction::Trunc:
-    case Instruction::ZExt:
-    case Instruction::SExt:
-    case Instruction::FPToUI:
-    case Instruction::FPToSI:
-    case Instruction::UIToFP:
-    case Instruction::SIToFP:
-    case Instruction::FPTrunc:
-    case Instruction::FPExt:
-    case Instruction::PtrToInt:
-    case Instruction::IntToPtr:
-    case Instruction::BitCast:
-      return dyn_cast<Instruction>(
-                  Builder.CreateCast((Instruction::CastOps)OpCode,
-                                     CE->getOperand(0), CE->getType(),
-                                     CE->getName()));
-    default:
-      llvm_unreachable("Unhandled constant expression!\n");
-  }
-}
 
 static bool replaceConstantExprOp(ConstantExpr *CE, Pass *P) {
   do {
Index: llvm/lib/IR/ReplaceConstant.cpp
===================================================================
--- /dev/null
+++ llvm/lib/IR/ReplaceConstant.cpp
@@ -0,0 +1,81 @@
+//===- ReplaceConstant.cpp - Replace LLVM constant expressions-------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a utility function for replacing LLVM constant
+// expressions by instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/ReplaceConstant.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/NoFolder.h"
+
+namespace llvm {
+// The builder uses NoFolder, so the Create* calls below produce a new
+// instruction for constant operands instead of folding them; the results are
+// therefore cast<> (asserting) rather than dyn_cast<>, so a violation is
+// caught here instead of surfacing as a null result. The nuw/nsw/exact flags
+// of binary-operator expressions are not transferred to the new instruction.
+Instruction *createReplacementInstr(ConstantExpr *CE, Instruction *Instr) {
+  IRBuilder<NoFolder> Builder(Instr);
+  unsigned OpCode = CE->getOpcode();
+  switch (OpCode) {
+  case Instruction::GetElementPtr: {
+    SmallVector<Value *, 4> CEOpVec(CE->operands());
+    ArrayRef<Value *> CEOps(CEOpVec);
+    auto *GEP = cast<GEPOperator>(CE);
+    // Only claim inbounds if the constant expression did.
+    if (GEP->isInBounds())
+      return cast<Instruction>(Builder.CreateInBoundsGEP(
+          GEP->getSourceElementType(), CEOps[0], CEOps.slice(1)));
+    return cast<Instruction>(Builder.CreateGEP(GEP->getSourceElementType(),
+                                               CEOps[0], CEOps.slice(1)));
+  }
+  case Instruction::Add:
+  case Instruction::FAdd:
+  case Instruction::Sub:
+  case Instruction::FSub:
+  case Instruction::Mul:
+  case Instruction::FMul:
+  case Instruction::UDiv:
+  case Instruction::SDiv:
+  case Instruction::FDiv:
+  case Instruction::URem:
+  case Instruction::SRem:
+  case Instruction::FRem:
+  case Instruction::Shl:
+  case Instruction::LShr:
+  case Instruction::AShr:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor:
+    return cast<Instruction>(
+        Builder.CreateBinOp((Instruction::BinaryOps)OpCode, CE->getOperand(0),
+                            CE->getOperand(1), CE->getName()));
+  case Instruction::Trunc:
+  case Instruction::ZExt:
+  case Instruction::SExt:
+  case Instruction::FPToUI:
+  case Instruction::FPToSI:
+  case Instruction::UIToFP:
+  case Instruction::SIToFP:
+  case Instruction::FPTrunc:
+  case Instruction::FPExt:
+  case Instruction::PtrToInt:
+  case Instruction::IntToPtr:
+  case Instruction::BitCast:
+  case Instruction::AddrSpaceCast:
+    return cast<Instruction>(
+        Builder.CreateCast((Instruction::CastOps)OpCode, CE->getOperand(0),
+                           CE->getType(), CE->getName()));
+  default:
+    llvm_unreachable("Unhandled constant expression!\n");
+  }
+}
+} // namespace llvm
Index: llvm/lib/IR/CMakeLists.txt
===================================================================
--- llvm/lib/IR/CMakeLists.txt
+++ llvm/lib/IR/CMakeLists.txt
@@ -49,6 +49,7 @@
   SafepointIRVerifier.cpp
   ProfileSummary.cpp
   PseudoProbe.cpp
+  ReplaceConstant.cpp
   Statepoint.cpp
   StructuralHash.cpp
   Type.cpp
Index: llvm/include/llvm/IR/ReplaceConstant.h
===================================================================
--- /dev/null
+++ llvm/include/llvm/IR/ReplaceConstant.h
@@ -0,0 +1,33 @@
+//===- ReplaceConstant.h - Replacing LLVM constant expressions --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the utility function for replacing LLVM constant
+// expressions by instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_IR_REPLACECONSTANT_H
+#define LLVM_IR_REPLACECONSTANT_H
+
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Instruction.h"
+
+namespace llvm {
+
+/// Create a replacement instruction for constant expression \p CE and insert
+/// it before \p Instr.
+///
+/// The new instruction has the same opcode, type and operands as \p CE;
+/// callers are expected to rewrite its operands afterwards as needed. Only
+/// getelementptr and a subset of binary-operator and cast expressions are
+/// supported; other opcodes hit llvm_unreachable.
+Instruction *createReplacementInstr(ConstantExpr *CE, Instruction *Instr);
+
+} // end namespace llvm
+
+#endif // LLVM_IR_REPLACECONSTANT_H
Index: clang/test/Misc/pragma-attribute-supported-attributes-list.test
===================================================================
--- clang/test/Misc/pragma-attribute-supported-attributes-list.test
+++ clang/test/Misc/pragma-attribute-supported-attributes-list.test
@@ -64,6 +64,7 @@
 // CHECK-NEXT: FlagEnum (SubjectMatchRule_enum)
 // CHECK-NEXT: Flatten (SubjectMatchRule_function)
 // CHECK-NEXT: GNUInline (SubjectMatchRule_function)
+// CHECK-NEXT: HIPManaged (SubjectMatchRule_variable)
 // CHECK-NEXT: Hot (SubjectMatchRule_function)
 // CHECK-NEXT: IBAction (SubjectMatchRule_objc_method_is_instance)
 // CHECK-NEXT: IFunc (SubjectMatchRule_function)
Index: clang/test/CodeGenCUDA/managed-var.cu
===================================================================
--- /dev/null
+++ clang/test/CodeGenCUDA/managed-var.cu
@@ -0,0 +1,94 @@
+// REQUIRES: x86-registered-target, amdgpu-registered-target
+
+// RUN: %clang_cc1 -triple amdgcn-amd-amdhsa -fcuda-is-device -std=c++11 \
+// RUN:   -emit-llvm -o - -x hip %s | FileCheck \
+// RUN:   -check-prefixes=DEV %s
+
+// RUN: %clang_cc1 -triple x86_64-gnu-linux -std=c++11 \
+// RUN:   -emit-llvm -o - -x hip %s | FileCheck \
+// RUN:   -check-prefixes=HOST %s
+
+#include "Inputs/cuda.h"
+
+
+// Check the device-side definition and host-side shadow of a managed variable.
+// DEV-DAG: @x = dso_local addrspace(1) externally_initialized global i32 undef
+// HOST-DAG: @x = internal global i32 1
+// HOST-DAG: @x.managed = external global i32*
+// HOST-DAG: @[[DEVNAMEX:[0-9]+]] = {{.*}}c"x\00"
+
+struct vec {
+  float x,y,z;
+};
+
+__managed__ int x = 1;
+__managed__ vec v[100];
+__managed__ vec v2[100] = {{1, 1, 1}};
+
+__global__ void foo(int *z) {
+  *z = x;
+  v[1].x = 2;
+}
+
+// HOST-LABEL: define {{.*}}@_Z4loadv()
+// HOST:  %ld.managed = load i32*, i32** @x.managed, align 4
+// HOST:  %0 = load i32, i32* %ld.managed, align 4
+// HOST:  ret i32 %0
+int load() {
+  return x;
+}
+
+// HOST-LABEL: define {{.*}}@_Z5storev()
+// HOST:  %ld.managed = load i32*, i32** @x.managed, align 4
+// HOST:  store i32 2, i32* %ld.managed, align 4
+void store() {
+  x = 2;
+}
+
+// HOST-LABEL: define {{.*}}@_Z10addr_takenv()
+// HOST:  %ld.managed = load i32*, i32** @x.managed, align 4
+// HOST:  store i32* %ld.managed, i32** %p, align 8
+// HOST:  %0 = load i32*, i32** %p, align 8
+// HOST:  store i32 3, i32* %0, align 4
+void addr_taken() {
+  int *p = &x;
+  *p = 3;
+}
+
+// HOST-LABEL: define {{.*}}@_Z5load2v()
+// HOST: %ld.managed = load [100 x %struct.vec]*, [100 x %struct.vec]** @v.managed, align 16
+// HOST:  %0 = getelementptr inbounds [100 x %struct.vec], [100 x %struct.vec]* %ld.managed, i64 0, i64 1, i32 0
+// HOST:  %1 = load float, float* %0, align 4
+// HOST:  ret float %1
+float load2() {
+  return v[1].x;
+}
+
+// HOST-LABEL: define {{.*}}@_Z5load3v()
+// HOST:  %ld.managed = load <{ %struct.vec, [99 x %struct.vec] }>*, <{ %struct.vec, [99 x %struct.vec] }>** @v2.managed, align 16
+// HOST:  %0 = bitcast <{ %struct.vec, [99 x %struct.vec] }>* %ld.managed to [100 x %struct.vec]*
+// HOST:  %1 = getelementptr inbounds [100 x %struct.vec], [100 x %struct.vec]* %0, i64 0, i64 1, i32 1
+// HOST:  %2 = load float, float* %1, align 4
+// HOST:  ret float %2
+float load3() {
+  return v2[1].y;
+}
+
+// HOST-LABEL: define {{.*}}@_Z11addr_taken2v()
+// HOST:  %ld.managed = load [100 x %struct.vec]*, [100 x %struct.vec]** @v.managed, align 16
+// HOST:  %0 = getelementptr inbounds [100 x %struct.vec], [100 x %struct.vec]* %ld.managed, i64 0, i64 1, i32 0
+// HOST:  %1 = ptrtoint float* %0 to i64
+// HOST:  %ld.managed1 = load <{ %struct.vec, [99 x %struct.vec] }>*, <{ %struct.vec, [99 x %struct.vec] }>** @v2.managed, align 16
+// HOST:  %2 = bitcast <{ %struct.vec, [99 x %struct.vec] }>* %ld.managed1 to [100 x %struct.vec]*
+// HOST:  %3 = getelementptr inbounds [100 x %struct.vec], [100 x %struct.vec]* %2, i64 0, i64 1, i32 1
+// HOST:  %4 = ptrtoint float* %3 to i64
+// HOST:  %5 = sub i64 %4, %1
+// HOST:  %6 = sdiv i64 %5, 4
+// HOST:  %7 = sitofp i64 %6 to float
+// HOST:  ret float %7
+float addr_taken2() {
+  return (float)reinterpret_cast<long>(&(v2[1].y)-&(v[1].x));
+}
+
+// HOST-DAG: __hipRegisterManagedVar({{.*}}@x.managed {{.*}}@x {{.*}}@[[DEVNAMEX]]{{.*}}, i64 4, i32 4)
+// HOST-DAG: declare void @__hipRegisterManagedVar(i8**, i8*, i8*, i8*, i64, i32)
Index: clang/test/CodeGenCUDA/Inputs/cuda.h
===================================================================
--- clang/test/CodeGenCUDA/Inputs/cuda.h
+++ clang/test/CodeGenCUDA/Inputs/cuda.h
@@ -7,6 +7,9 @@
 #define __global__ __attribute__((global))
 #define __host__ __attribute__((host))
 #define __shared__ __attribute__((shared))
+#if __HIP__
+#define __managed__ __attribute__((managed))
+#endif
 #define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__)))
 
 struct dim3 {
Index: clang/lib/Sema/SemaExpr.cpp
===================================================================
--- clang/lib/Sema/SemaExpr.cpp
+++ clang/lib/Sema/SemaExpr.cpp
@@ -364,6 +364,7 @@
       const auto *VD = dyn_cast<VarDecl>(D);
       if (VD && VD->hasGlobalStorage() && !VD->hasAttr<CUDADeviceAttr>() &&
           !VD->hasAttr<CUDAConstantAttr>() && !VD->hasAttr<CUDASharedAttr>() &&
+          !VD->hasAttr<HIPManagedAttr>() &&
           !VD->getType()->isCUDADeviceBuiltinSurfaceType() &&
           !VD->getType()->isCUDADeviceBuiltinTextureType() &&
           !VD->isConstexpr() && !VD->getType().isConstQualified())
Index: clang/lib/Sema/SemaDeclAttr.cpp
===================================================================
--- clang/lib/Sema/SemaDeclAttr.cpp
+++ clang/lib/Sema/SemaDeclAttr.cpp
@@ -7724,6 +7724,10 @@
   case ParsedAttr::AT_CUDAHost:
     handleSimpleAttributeWithExclusions<CUDAHostAttr, CUDAGlobalAttr>(S, D, AL);
     break;
+  case ParsedAttr::AT_HIPManaged:
+    handleSimpleAttributeWithExclusions<HIPManagedAttr, CUDAGlobalAttr>(S, D,
+                                                                        AL);
+    break;
   case ParsedAttr::AT_CUDADeviceBuiltinSurfaceType:
     handleSimpleAttributeWithExclusions<CUDADeviceBuiltinSurfaceTypeAttr,
                                         CUDADeviceBuiltinTextureTypeAttr>(S, D,
Index: clang/lib/Sema/SemaCUDA.cpp
===================================================================
--- clang/lib/Sema/SemaCUDA.cpp
+++ clang/lib/Sema/SemaCUDA.cpp
@@ -516,7 +516,7 @@
     return;
   const Expr *Init = VD->getInit();
   if (VD->hasAttr<CUDADeviceAttr>() || VD->hasAttr<CUDAConstantAttr>() ||
-      VD->hasAttr<CUDASharedAttr>()) {
+      VD->hasAttr<CUDASharedAttr>() || VD->hasAttr<HIPManagedAttr>()) {
     if (LangOpts.GPUAllowDeviceInit)
       return;
     bool AllowedInit = false;
@@ -527,7 +527,8 @@
     // constructor according to CUDA rules. This deviates from NVCC,
     // but allows us to handle things like constexpr constructors.
     if (!AllowedInit &&
-        (VD->hasAttr<CUDADeviceAttr>() || VD->hasAttr<CUDAConstantAttr>())) {
+        (VD->hasAttr<CUDADeviceAttr>() || VD->hasAttr<CUDAConstantAttr>() ||
+         VD->hasAttr<HIPManagedAttr>())) {
       auto *Init = VD->getInit();
       AllowedInit =
           ((VD->getType()->isDependentType() || Init->isValueDependent()) &&
Index: clang/lib/CodeGen/TargetInfo.cpp
===================================================================
--- clang/lib/CodeGen/TargetInfo.cpp
+++ clang/lib/CodeGen/TargetInfo.cpp
@@ -8958,6 +8958,7 @@
          (isa<FunctionDecl>(D) && D->hasAttr<CUDAGlobalAttr>()) ||
          (isa<VarDecl>(D) &&
           (D->hasAttr<CUDADeviceAttr>() || D->hasAttr<CUDAConstantAttr>() ||
+           D->hasAttr<HIPManagedAttr>() ||
            cast<VarDecl>(D)->getType()->isCUDADeviceBuiltinSurfaceType() ||
            cast<VarDecl>(D)->getType()->isCUDADeviceBuiltinTextureType()));
 }
Index: clang/lib/CodeGen/CodeGenModule.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenModule.cpp
+++ clang/lib/CodeGen/CodeGenModule.cpp
@@ -2740,6 +2740,7 @@
           !Global->hasAttr<CUDAGlobalAttr>() &&
           !Global->hasAttr<CUDAConstantAttr>() &&
           !Global->hasAttr<CUDASharedAttr>() &&
+          !Global->hasAttr<HIPManagedAttr>() &&
           !Global->getType()->isCUDADeviceBuiltinSurfaceType() &&
           !Global->getType()->isCUDADeviceBuiltinTextureType())
         return;
@@ -3993,7 +3994,8 @@
       return LangAS::cuda_constant;
     else if (D && D->hasAttr<CUDASharedAttr>())
       return LangAS::cuda_shared;
-    else if (D && D->hasAttr<CUDADeviceAttr>())
+    else if (D &&
+             (D->hasAttr<CUDADeviceAttr>() || D->hasAttr<HIPManagedAttr>()))
       return LangAS::cuda_device;
     else if (D && D->getType().isConstQualified())
       return LangAS::cuda_constant;
@@ -4151,7 +4153,8 @@
   bool IsCUDADeviceShadowVar =
       getLangOpts().CUDAIsDevice &&
       (D->getType()->isCUDADeviceBuiltinSurfaceType() ||
-       D->getType()->isCUDADeviceBuiltinTextureType());
+       D->getType()->isCUDADeviceBuiltinTextureType() ||
+       D->hasAttr<HIPManagedAttr>());
   // HIP pinned shadow of initialized host-side global variables are also
   // left undefined.
   if (getLangOpts().CUDA &&
@@ -4256,14 +4259,16 @@
   if (GV && LangOpts.CUDA) {
     if (LangOpts.CUDAIsDevice) {
       if (Linkage != llvm::GlobalValue::InternalLinkage &&
-          (D->hasAttr<CUDADeviceAttr>() || D->hasAttr<CUDAConstantAttr>()))
+          (D->hasAttr<CUDADeviceAttr>() || D->hasAttr<CUDAConstantAttr>() ||
+           D->hasAttr<HIPManagedAttr>()))
         GV->setExternallyInitialized(true);
     } else {
       // Host-side shadows of external declarations of device-side
       // global variables become internal definitions. These have to
       // be internal in order to prevent name conflicts with global
       // host variables with the same name in a different TUs.
-      if (D->hasAttr<CUDADeviceAttr>() || D->hasAttr<CUDAConstantAttr>()) {
+      if (D->hasAttr<CUDADeviceAttr>() || D->hasAttr<CUDAConstantAttr>() ||
+          D->hasAttr<HIPManagedAttr>()) {
         Linkage = llvm::GlobalValue::InternalLinkage;
         // Shadow variables and their properties must be registered with CUDA
         // runtime. Skip Extern global variables, which will be registered in
Index: clang/lib/CodeGen/CGDeclCXX.cpp
===================================================================
--- clang/lib/CodeGen/CGDeclCXX.cpp
+++ clang/lib/CodeGen/CGDeclCXX.cpp
@@ -460,7 +460,7 @@
   // are allowed are empty and we just need to ignore them here.
   if (getLangOpts().CUDAIsDevice && !getLangOpts().GPUAllowDeviceInit &&
       (D->hasAttr<CUDADeviceAttr>() || D->hasAttr<CUDAConstantAttr>() ||
-       D->hasAttr<CUDASharedAttr>()))
+       D->hasAttr<CUDASharedAttr>() || D->hasAttr<HIPManagedAttr>()))
     return;
 
   if (getLangOpts().OpenMP &&
Index: clang/lib/CodeGen/CGCUDARuntime.h
===================================================================
--- clang/lib/CodeGen/CGCUDARuntime.h
+++ clang/lib/CodeGen/CGCUDARuntime.h
@@ -54,16 +54,19 @@
     unsigned Kind : 2;
     unsigned Extern : 1;
     unsigned Constant : 1;   // Constant variable.
+    unsigned Managed : 1;    // Managed variable.
     unsigned Normalized : 1; // Normalized texture.
     int SurfTexType;         // Type of surface/texutre.
 
   public:
-    DeviceVarFlags(DeviceVarKind K, bool E, bool C, bool N, int T)
-        : Kind(K), Extern(E), Constant(C), Normalized(N), SurfTexType(T) {}
+    DeviceVarFlags(DeviceVarKind K, bool E, bool C, bool M, bool N, int T)
+        : Kind(K), Extern(E), Constant(C), Managed(M), Normalized(N),
+          SurfTexType(T) {}
 
     DeviceVarKind getKind() const { return static_cast<DeviceVarKind>(Kind); }
     bool isExtern() const { return Extern; }
     bool isConstant() const { return Constant; }
+    bool isManaged() const { return Managed; }
     bool isNormalized() const { return Normalized; }
     int getSurfTexType() const { return SurfTexType; }
   };
Index: clang/lib/CodeGen/CGCUDANV.cpp
===================================================================
--- clang/lib/CodeGen/CGCUDANV.cpp
+++ clang/lib/CodeGen/CGCUDANV.cpp
@@ -21,6 +21,7 @@
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/ReplaceConstant.h"
 #include "llvm/Support/Format.h"
 
 using namespace clang;
@@ -128,13 +129,15 @@
     DeviceVars.push_back({&Var,
                           VD,
                           {DeviceVarFlags::Variable, Extern, Constant,
-                           /*Normalized*/ false, /*Type*/ 0}});
+                           VD->hasAttr<HIPManagedAttr>(),
+                           /*Normalized*/ false, 0}});
   }
   void registerDeviceSurf(const VarDecl *VD, llvm::GlobalVariable &Var,
                           bool Extern, int Type) override {
     DeviceVars.push_back({&Var,
                           VD,
                           {DeviceVarFlags::Surface, Extern, /*Constant*/ false,
+                           /*Managed*/ false,
                            /*Normalized*/ false, Type}});
   }
   void registerDeviceTex(const VarDecl *VD, llvm::GlobalVariable &Var,
@@ -142,7 +145,7 @@
     DeviceVars.push_back({&Var,
                           VD,
                           {DeviceVarFlags::Texture, Extern, /*Constant*/ false,
-                           Normalized, Type}});
+                           /*Managed*/ false, Normalized, Type}});
   }
 
   /// Creates module constructor function
@@ -380,6 +383,91 @@
   CGF.EmitBlock(EndBlock);
 }

+/// Emit, before \p InsertPt, a load of \p ManagedVar followed by instructions
+/// replacing each constant expression in \p CEChain, ordered from the one that
+/// uses \p Var outwards. Returns the value standing in for the last element
+/// of the chain, or the load itself when the chain is empty.
+///
+/// FIXME: The load uses Var's alignment although it reads the pointer stored
+/// in ManagedVar, which has no explicit alignment; for variables aligned above
+/// the pointer ABI alignment this over-states the alignment of the load.
+static llvm::Value *materializeManagedVarUse(ArrayRef<llvm::User *> CEChain,
+                                             llvm::GlobalVariable *Var,
+                                             llvm::GlobalVariable *ManagedVar,
+                                             llvm::Instruction *InsertPt) {
+  llvm::Value *OldV = Var;
+  llvm::Value *NewV =
+      new llvm::LoadInst(Var->getType(), ManagedVar, "ld.managed", false,
+                         llvm::Align(Var->getAlignment()), InsertPt);
+  for (llvm::User *Op : CEChain) {
+    auto *CE = cast<llvm::ConstantExpr>(Op);
+    llvm::Instruction *NewInst = llvm::createReplacementInstr(CE, InsertPt);
+    NewInst->replaceUsesOfWith(OldV, NewV);
+    OldV = CE;
+    NewV = NewInst;
+  }
+  return NewV;
+}
+
+// Replace the original variable Var with the address loaded from variable
+// ManagedVar populated by HIP runtime. Every instruction using Var, directly
+// or through a chain of constant expressions, gets its own load of ManagedVar.
+// Uses that are neither instructions nor constant expressions (e.g. a global
+// initializer taking the variable's address) are not supported yet.
+static void replaceManagedVar(llvm::GlobalVariable *Var,
+                              llvm::GlobalVariable *ManagedVar) {
+  // Each work item is a use chain: the constant expressions leading from Var,
+  // followed by the user currently being examined.
+  SmallVector<SmallVector<llvm::User *, 8>, 8> WorkList;
+  for (llvm::User *U : Var->users())
+    WorkList.push_back({U});
+  while (!WorkList.empty()) {
+    SmallVector<llvm::User *, 8> Chain = WorkList.pop_back_val();
+    llvm::User *U = Chain.back();
+    if (isa<llvm::ConstantExpr>(U)) {
+      // Extend the chain through every user of this constant expression.
+      for (llvm::User *UU : U->users()) {
+        Chain.push_back(UU);
+        WorkList.push_back(Chain);
+        Chain.pop_back();
+      }
+      continue;
+    }
+    auto *I = dyn_cast<llvm::Instruction>(U);
+    if (!I)
+      llvm_unreachable("Invalid use of managed variable");
+    Chain.pop_back();
+    // The operand of I that (directly or indirectly) refers to Var.
+    llvm::Value *Used = Var;
+    if (!Chain.empty())
+      Used = Chain.back();
+    if (auto *Phi = dyn_cast<llvm::PHINode>(I)) {
+      // Nothing may be inserted in front of a PHI, so materialize the address
+      // at the end of each incoming block instead. Entries for the same block
+      // must carry the same value, hence the per-block cache.
+      llvm::SmallDenseMap<llvm::BasicBlock *, llvm::Value *, 4> PerBlock;
+      for (unsigned Idx = 0, E = Phi->getNumIncomingValues(); Idx != E;
+           ++Idx) {
+        if (Phi->getIncomingValue(Idx) != Used)
+          continue;
+        llvm::BasicBlock *BB = Phi->getIncomingBlock(Idx);
+        llvm::Value *&NewV = PerBlock[BB];
+        if (!NewV)
+          NewV = materializeManagedVarUse(Chain, Var, ManagedVar,
+                                          BB->getTerminator());
+        Phi->setIncomingValue(Idx, NewV);
+      }
+      continue;
+    }
+    // An earlier work item may already have rewritten every operand of I that
+    // referred to Used (e.g. when I uses it twice); skip to avoid a dead load.
+    if (!llvm::is_contained(I->operand_values(), Used))
+      continue;
+    I->replaceUsesOfWith(Used,
+                         materializeManagedVarUse(Chain, Var, ManagedVar, I));
+  }
+}
+
 /// Creates a function that sets up state on the host side for CUDA objects that
 /// have a presence on both the host and device sides. Specifically, registers
 /// the host side of kernel functions and device global variables with the CUDA
@@ -452,6 +540,13 @@
   llvm::FunctionCallee RegisterVar = CGM.CreateRuntimeFunction(
       llvm::FunctionType::get(VoidTy, RegisterVarParams, false),
       addUnderscoredPrefixToName("RegisterVar"));
+  // void __hipRegisterManagedVar(void **, char *, char *, const char *,
+  //                              size_t, unsigned)
+  llvm::Type *RegisterManagedVarParams[] = {VoidPtrPtrTy, CharPtrTy, CharPtrTy,
+                                            CharPtrTy,    VarSizeTy, IntTy};
+  llvm::FunctionCallee RegisterManagedVar = CGM.CreateRuntimeFunction(
+      llvm::FunctionType::get(VoidTy, RegisterManagedVarParams, false),
+      addUnderscoredPrefixToName("RegisterManagedVar"));
   // void __cudaRegisterSurface(void **, const struct surfaceReference *,
   //                            const void **, const char *, int, int);
   llvm::FunctionCallee RegisterSurf = CGM.CreateRuntimeFunction(
@@ -474,16 +569,36 @@
     case DeviceVarFlags::Variable: {
       uint64_t VarSize =
           CGM.getDataLayout().getTypeAllocSize(Var->getValueType());
-      llvm::Value *Args[] = {
-          &GpuBinaryHandlePtr,
-          Builder.CreateBitCast(Var, VoidPtrTy),
-          VarName,
-          VarName,
-          llvm::ConstantInt::get(IntTy, Info.Flags.isExtern()),
-          llvm::ConstantInt::get(VarSizeTy, VarSize),
-          llvm::ConstantInt::get(IntTy, Info.Flags.isConstant()),
-          llvm::ConstantInt::get(IntTy, 0)};
-      Builder.CreateCall(RegisterVar, Args);
+      if (Info.Flags.isManaged()) {
+        // NOTE(review): ManagedVar is only declared here (external linkage, no
+        // initializer) although the HIP runtime populates it; confirm another
+        // object defines it, or emit a null-initialized definition instead.
+        auto *ManagedVar = new llvm::GlobalVariable(
+            CGM.getModule(), Var->getType(),
+            /*isConstant=*/false, llvm::GlobalValue::ExternalLinkage, nullptr,
+            Twine(Var->getName() + ".managed"), nullptr,
+            llvm::GlobalVariable::NotThreadLocal);
+        replaceManagedVar(Var, ManagedVar);
+        llvm::Value *Args[] = {
+            &GpuBinaryHandlePtr,
+            Builder.CreateBitCast(ManagedVar, VoidPtrTy),
+            Builder.CreateBitCast(Var, VoidPtrTy),
+            VarName,
+            llvm::ConstantInt::get(VarSizeTy, VarSize),
+            llvm::ConstantInt::get(IntTy, Var->getAlignment())};
+        Builder.CreateCall(RegisterManagedVar, Args);
+      } else {
+        llvm::Value *Args[] = {
+            &GpuBinaryHandlePtr,
+            Builder.CreateBitCast(Var, VoidPtrTy),
+            VarName,
+            VarName,
+            llvm::ConstantInt::get(IntTy, Info.Flags.isExtern()),
+            llvm::ConstantInt::get(VarSizeTy, VarSize),
+            llvm::ConstantInt::get(IntTy, Info.Flags.isConstant()),
+            llvm::ConstantInt::get(IntTy, 0)};
+        Builder.CreateCall(RegisterVar, Args);
+      }
       break;
     }
     case DeviceVarFlags::Surface:
Index: clang/include/clang/Sema/SemaInternal.h
===================================================================
--- clang/include/clang/Sema/SemaInternal.h
+++ clang/include/clang/Sema/SemaInternal.h
@@ -44,9 +44,9 @@
 inline bool DeclAttrsMatchCUDAMode(const LangOptions &LangOpts, Decl *D) {
   if (!LangOpts.CUDA || !D)
     return true;
-  bool isDeviceSideDecl = D->hasAttr<CUDADeviceAttr>() ||
-                          D->hasAttr<CUDASharedAttr>() ||
-                          D->hasAttr<CUDAGlobalAttr>();
+  bool isDeviceSideDecl =
+      D->hasAttr<CUDADeviceAttr>() || D->hasAttr<CUDASharedAttr>() ||
+      D->hasAttr<CUDAGlobalAttr>() || D->hasAttr<HIPManagedAttr>();
   return isDeviceSideDecl == LangOpts.CUDAIsDevice;
 }
 
Index: clang/include/clang/Basic/AttrDocs.td
===================================================================
--- clang/include/clang/Basic/AttrDocs.td
+++ clang/include/clang/Basic/AttrDocs.td
@@ -5419,6 +5419,17 @@
   }];
 }
 
+def HIPManagedAttrDocs : Documentation {
+  let Category = DocCatDecl;
+  let Content = [{
+The ``__managed__`` attribute can be applied to a global variable declaration in HIP.
+A managed variable is emitted as an undefined global symbol in the device binary and is
+registered by ``__hipRegisterManagedVar`` in init functions. The HIP runtime allocates
+managed memory and uses it to define the symbol when loading the device binary.
+A managed variable can be accessed in both device and host code.
+  }];
+}
+
 def LifetimeOwnerDocs : Documentation {
   let Category = DocCatDecl;
   let Content = [{
Index: clang/include/clang/Basic/Attr.td
===================================================================
--- clang/include/clang/Basic/Attr.td
+++ clang/include/clang/Basic/Attr.td
@@ -324,6 +324,7 @@
 def MicrosoftExt : LangOpt<"MicrosoftExt">;
 def Borland : LangOpt<"Borland">;
 def CUDA : LangOpt<"CUDA">;
+def HIP : LangOpt<"HIP">;
 def SYCL : LangOpt<"SYCLIsDevice">;
 def COnly : LangOpt<"", "!LangOpts.CPlusPlus">;
 def CPlusPlus : LangOpt<"CPlusPlus">;
@@ -1108,6 +1109,13 @@
   let Documentation = [Undocumented];
 }
 
+def HIPManaged : InheritableAttr {
+  let Spellings = [GNU<"managed">, Declspec<"__managed__">];
+  let Subjects = SubjectList<[Var]>;
+  let LangOpts = [HIP];
+  let Documentation = [HIPManagedAttrDocs];
+}
+
 def CUDAInvalidTarget : InheritableAttr {
   let Spellings = [];
   let Subjects = SubjectList<[Function]>;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to