hliao created this revision.
hliao added a reviewer: tra.
Herald added subscribers: llvm-commits, cfe-commits, arphaman, hiraditya, yaxunl, mgorny, jholewinski.
Herald added projects: clang, LLVM.
hliao requested review of this revision.

Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D91590

Files:
  clang/lib/CodeGen/TargetInfo.cpp
  clang/test/CodeGen/nvptx-abi.c
  clang/test/CodeGenCUDA/kernel-args-alignment.cu
  clang/test/CodeGenCUDA/kernel-args.cu
  clang/test/OpenMP/nvptx_unsupported_type_codegen.cpp
  llvm/lib/Target/NVPTX/CMakeLists.txt
  llvm/lib/Target/NVPTX/NVPTX.h
  llvm/lib/Target/NVPTX/NVPTXAA.cpp
  llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
  llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
  llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Index: llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -198,11 +198,23 @@
 
 void NVPTXTargetMachine::adjustPassManager(PassManagerBuilder &Builder) {
   Builder.addExtension(
-    PassManagerBuilder::EP_EarlyAsPossible,
-    [&](const PassManagerBuilder &, legacy::PassManagerBase &PM) {
-      PM.add(createNVVMReflectPass(Subtarget.getSmVersion()));
-      PM.add(createNVVMIntrRangePass(Subtarget.getSmVersion()));
-    });
+      PassManagerBuilder::EP_EarlyAsPossible,
+      [&](const PassManagerBuilder &, legacy::PassManagerBase &PM) {
+        // Make the NVPTX address-space alias analysis available to the
+        // passes scheduled at this extension point.
+        PM.add(createNVPTXAAWrapperPass());
+        PM.add(createNVPTXExternalAAWrapperPass());
+        PM.add(createNVVMReflectPass(Subtarget.getSmVersion()));
+        PM.add(createNVVMIntrRangePass(Subtarget.getSmVersion()));
+      });
+  Builder.addExtension(
+      PassManagerBuilder::EP_ModuleOptimizerEarly,
+      [&](const PassManagerBuilder &, legacy::PassManagerBase &PM) {
+        // Register the same alias analysis with the pass manager used for the
+        // module-level optimization pipeline.
+        PM.add(createNVPTXAAWrapperPass());
+        PM.add(createNVPTXExternalAAWrapperPass());
+      });
 }
 
 TargetTransformInfo
@@ -279,6 +291,10 @@
     addStraightLineScalarOptimizationPasses();
   }
 
+  // Make the NVPTX alias analysis available to the generic IR passes below.
+  addPass(createNVPTXAAWrapperPass());
+  addPass(createNVPTXExternalAAWrapperPass());
+
   // === LSR and other generic IR passes ===
   TargetPassConfig::addIRPasses();
   // EarlyCSE is not always strong enough to clean up what LSR produces. For
Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -2531,7 +2531,8 @@
     // to newly created nodes. The SDNodes for params have to
     // appear in the same order as their order of appearance
     // in the original function. "idx+1" holds that order.
-    if (!PAL.hasParamAttribute(i, Attribute::ByVal)) {
+    if (!PAL.hasParamAttribute(i, Attribute::ByVal) &&
+        !PAL.hasParamAttribute(i, Attribute::ByRef)) {
       bool aggregateIsPacked = false;
       if (StructType *STy = dyn_cast<StructType>(Ty))
         aggregateIsPacked = STy->isPacked();
Index: llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -1457,7 +1457,8 @@
       }
     }
 
-    if (!PAL.hasParamAttribute(paramIndex, Attribute::ByVal)) {
+    if (!PAL.hasParamAttribute(paramIndex, Attribute::ByVal) &&
+        !PAL.hasParamAttribute(paramIndex, Attribute::ByRef)) {
       if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
         // Just print .param .align <a> .b8 .param[size];
         // <a> = PAL.getparamalignment
Index: llvm/lib/Target/NVPTX/NVPTXAA.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Target/NVPTX/NVPTXAA.cpp
@@ -0,0 +1,149 @@
+//===- NVPTXAA.cpp - NVPTX alias analysis ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Address-space based alias analysis for NVPTX, together with the legacy
+// pass-manager wrappers that register it with AAResults.
+//
+//===----------------------------------------------------------------------===//
+
+#include "MCTargetDesc/NVPTXBaseInfo.h"
+#include "NVPTX.h"
+#include "llvm/ADT/Triple.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Module.h"
+
+namespace llvm {
+void initializeNVPTXAAWrapperPass(PassRegistry &);
+void initializeNVPTXExternalAAWrapperPass(PassRegistry &);
+} // namespace llvm
+
+#define DEBUG_TYPE "nvptx-aa"
+
+using namespace llvm;
+
+/// Returns the address space \p Loc points into. A generic pointer is resolved
+/// to the address space of its underlying object; the result is still
+/// ADDRESS_SPACE_GENERIC when that object's address space is not known.
+static unsigned getResolvedAddressSpace(const MemoryLocation &Loc) {
+  unsigned AS = Loc.Ptr->getType()->getPointerAddressSpace();
+  if (AS != ADDRESS_SPACE_GENERIC)
+    return AS;
+  const Value *Obj =
+      getUnderlyingObject(Loc.Ptr->stripPointerCastsAndInvariantGroups());
+  return Obj->getType()->getPointerAddressSpace();
+}
+
+namespace {
+
+/// Address-space based alias analysis result for NVPTX. Locations whose
+/// pointers resolve to different non-generic address spaces are NoAlias, as
+/// is param memory versus any other space; the const and param spaces are
+/// treated as constant memory. All other queries go to the next analysis.
+class NVPTXAAResult : public AAResultBase<NVPTXAAResult> {
+  friend AAResultBase<NVPTXAAResult>;
+
+public:
+  explicit NVPTXAAResult() : AAResultBase() {}
+  NVPTXAAResult(NVPTXAAResult &&Arg) : AAResultBase(std::move(Arg)) {}
+
+  /// Handle invalidation events from the new pass manager. The result holds
+  /// no state, so it always stays valid.
+  bool invalidate(Function &, const PreservedAnalyses &,
+                  FunctionAnalysisManager::Invalidator &) {
+    return false;
+  }
+
+  AliasResult alias(const MemoryLocation &LocA, const MemoryLocation &LocB,
+                    AAQueryInfo &AAQI) {
+    const unsigned AS1 = getResolvedAddressSpace(LocA);
+    const unsigned AS2 = getResolvedAddressSpace(LocB);
+    if (AS1 != AS2) {
+      // Pointers into two distinct, known address spaces cannot alias.
+      if (AS1 != ADDRESS_SPACE_GENERIC && AS2 != ADDRESS_SPACE_GENERIC)
+        return NoAlias;
+      // Param memory is assumed reachable only through pointers whose
+      // underlying object is in the param space, so it cannot alias a generic
+      // pointer whose underlying object could not be resolved.
+      if (AS1 == ADDRESS_SPACE_PARAM || AS2 == ADDRESS_SPACE_PARAM)
+        return NoAlias;
+    }
+    // Nothing conclusive from address spaces; query the next alias analysis.
+    return AAResultBase::alias(LocA, LocB, AAQI);
+  }
+
+  bool pointsToConstantMemory(const MemoryLocation &Loc, AAQueryInfo &AAQI,
+                              bool OrLocal) {
+    unsigned AS = Loc.Ptr->getType()->getPointerAddressSpace();
+    // According to PTX ISA section 5.1.6.4, ``Function input parameters may be
+    // read via `ld.param` and function return parameters may be written using
+    // `st.param`; it is illegal to write to an input parameter or read from a
+    // return parameter.'' It's safe to assume that parameter memory space is
+    // constant.
+    if (AS == ADDRESS_SPACE_CONST || AS == ADDRESS_SPACE_PARAM)
+      return true;
+    return AAResultBase::pointsToConstantMemory(Loc, AAQI, OrLocal);
+  }
+};
+
+/// Legacy pass-manager pass that owns the NVPTXAAResult for a module.
+class NVPTXAAWrapper : public ImmutablePass {
+  std::unique_ptr<NVPTXAAResult> Result;
+
+public:
+  static char ID;
+
+  NVPTXAAWrapper() : ImmutablePass(ID) {
+    initializeNVPTXAAWrapperPass(*PassRegistry::getPassRegistry());
+  }
+
+  NVPTXAAResult &getResult() { return *Result; }
+  const NVPTXAAResult &getResult() const { return *Result; }
+
+  bool doInitialization(Module &M) override {
+    Result = std::make_unique<NVPTXAAResult>();
+    return false;
+  }
+
+  bool doFinalization(Module &M) override {
+    Result.reset();
+    return false;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesAll();
+  }
+};
+
+/// Adds the NVPTXAAResult to AAResults whenever an NVPTXAAWrapper pass is
+/// available to the querying pass.
+class NVPTXExternalAAWrapper : public ExternalAAWrapperPass {
+public:
+  static char ID;
+
+  NVPTXExternalAAWrapper()
+      : ExternalAAWrapperPass([](Pass &P, Function &F, AAResults &AAR) {
+          if (auto *WrapperPass = P.getAnalysisIfAvailable<NVPTXAAWrapper>())
+            AAR.addAAResult(WrapperPass->getResult());
+        }) {}
+};
+
+} // End of anonymous namespace
+
+char NVPTXAAWrapper::ID = 0;
+char NVPTXExternalAAWrapper::ID = 0;
+
+INITIALIZE_PASS(NVPTXAAWrapper, DEBUG_TYPE, "NVPTX AA Wrapper", true, true)
+INITIALIZE_PASS(NVPTXExternalAAWrapper, "nvptx-external-aa-wrapper",
+                "NVPTX ExternalAA Wrapper", true, true)
+
+ImmutablePass *llvm::createNVPTXAAWrapperPass() { return new NVPTXAAWrapper(); }
+
+ImmutablePass *llvm::createNVPTXExternalAAWrapperPass() {
+  return new NVPTXExternalAAWrapper();
+}
Index: llvm/lib/Target/NVPTX/NVPTX.h
===================================================================
--- llvm/lib/Target/NVPTX/NVPTX.h
+++ llvm/lib/Target/NVPTX/NVPTX.h
@@ -46,6 +46,8 @@
 FunctionPass *createNVPTXLowerAllocaPass();
 MachineFunctionPass *createNVPTXPeephole();
 MachineFunctionPass *createNVPTXProxyRegErasurePass();
+ImmutablePass *createNVPTXAAWrapperPass();
+ImmutablePass *createNVPTXExternalAAWrapperPass();
 
 namespace NVPTX {
 enum DrvInterface {
Index: llvm/lib/Target/NVPTX/CMakeLists.txt
===================================================================
--- llvm/lib/Target/NVPTX/CMakeLists.txt
+++ llvm/lib/Target/NVPTX/CMakeLists.txt
@@ -11,6 +11,7 @@
 add_public_tablegen_target(NVPTXCommonTableGen)
 
 set(NVPTXCodeGen_sources
+  NVPTXAA.cpp
   NVPTXAllocaHoisting.cpp
   NVPTXAsmPrinter.cpp
   NVPTXAssignValidGlobalNames.cpp
Index: clang/test/OpenMP/nvptx_unsupported_type_codegen.cpp
===================================================================
--- clang/test/OpenMP/nvptx_unsupported_type_codegen.cpp
+++ clang/test/OpenMP/nvptx_unsupported_type_codegen.cpp
@@ -34,7 +34,7 @@
 #pragma omp declare target
 T a = T();
 T f = a;
-// CHECK: define{{ hidden | }}void @{{.+}}foo{{.+}}([[T]]* byval([[T]]) align {{.+}})
+// CHECK: define{{ hidden | }}void @{{.+}}foo{{.+}}([[T]] addrspace(101)* byref([[T]]) align {{.+}})
 void foo(T a = T()) {
   return;
 }
@@ -54,7 +54,7 @@
 }
 T1 a1 = T1();
 T1 f1 = a1;
-// CHECK: define{{ hidden | }}void @{{.+}}foo1{{.+}}([[T1]]* byval([[T1]]) align {{.+}})
+// CHECK: define{{ hidden | }}void @{{.+}}foo1{{.+}}([[T1]] addrspace(101)* byref([[T1]]) align {{.+}})
 void foo1(T1 a = T1()) {
   return;
 }
Index: clang/test/CodeGenCUDA/kernel-args.cu
===================================================================
--- clang/test/CodeGenCUDA/kernel-args.cu
+++ clang/test/CodeGenCUDA/kernel-args.cu
@@ -10,14 +10,14 @@
 };
 
 // AMDGCN: define amdgpu_kernel void @_Z6kernel1A(%struct.A addrspace(4)* byref(%struct.A) align 8 %{{.+}})
-// NVPTX: define void @_Z6kernel1A(%struct.A* byval(%struct.A) align 8 %x)
+// NVPTX: define void @_Z6kernel1A(%struct.A addrspace(101)* byref(%struct.A) align 8 %0)
 __global__ void kernel(A x) {
 }
 
 class Kernel {
 public:
   // AMDGCN: define amdgpu_kernel void @_ZN6Kernel12memberKernelE1A(%struct.A addrspace(4)* byref(%struct.A) align 8 %{{.+}})
-  // NVPTX: define void @_ZN6Kernel12memberKernelE1A(%struct.A* byval(%struct.A) align 8 %x)
+  // NVPTX: define void @_ZN6Kernel12memberKernelE1A(%struct.A addrspace(101)* byref(%struct.A) align 8 %0)
   static __global__ void memberKernel(A x){}
   template<typename T> static __global__ void templateMemberKernel(T x) {}
 };
@@ -31,10 +31,10 @@
 void test() {
   Kernel K;
   // AMDGCN: define amdgpu_kernel void @_Z14templateKernelI1AEvT_(%struct.A addrspace(4)* byref(%struct.A) align 8 %{{.+}}
-  // NVPTX: define void @_Z14templateKernelI1AEvT_(%struct.A* byval(%struct.A) align 8 %x)
+  // NVPTX: define void @_Z14templateKernelI1AEvT_(%struct.A addrspace(101)* byref(%struct.A) align 8 %0)
   launch((void*)templateKernel<A>);
 
   // AMDGCN: define amdgpu_kernel void @_ZN6Kernel20templateMemberKernelI1AEEvT_(%struct.A addrspace(4)* byref(%struct.A) align 8 %{{.+}}
-  // NVPTX: define void @_ZN6Kernel20templateMemberKernelI1AEEvT_(%struct.A* byval(%struct.A) align 8 %x)
+  // NVPTX: define void @_ZN6Kernel20templateMemberKernelI1AEEvT_(%struct.A addrspace(101)* byref(%struct.A) align 8 %0)
   launch((void*)Kernel::templateMemberKernel<A>);
 }
Index: clang/test/CodeGenCUDA/kernel-args-alignment.cu
===================================================================
--- clang/test/CodeGenCUDA/kernel-args-alignment.cu
+++ clang/test/CodeGenCUDA/kernel-args-alignment.cu
@@ -36,5 +36,5 @@
 // HOST-OLD: call i32 @cudaSetupArgument({{[^,]*}}, i64 8, i64 24)
 
 // DEVICE-LABEL: @_Z6kernelc1SPi
-// DEVICE-SAME: i8{{[^,]*}}, %struct.S* byval(%struct.S) align 8{{[^,]*}}, i32*
+// DEVICE-SAME: i8{{[^,]*}}, %struct.S addrspace(101)* byref(%struct.S) align 8{{[^,]*}}, i32*
 __global__ void kernel(char a, S s, int *b) {}
Index: clang/test/CodeGen/nvptx-abi.c
===================================================================
--- clang/test/CodeGen/nvptx-abi.c
+++ clang/test/CodeGen/nvptx-abi.c
@@ -21,14 +21,14 @@
 
 void foo(float4_t x) {
 // CHECK-LABEL: @foo
-// CHECK: %struct.float4_s* byval(%struct.float4_s) align 4 %x
+// CHECK: %struct.float4_s addrspace(101)* byref(%struct.float4_s) align 4 %0
 }
 
 void fooN(float4_t x, float4_t y, float4_t z) {
 // CHECK-LABEL: @fooN
-// CHECK: %struct.float4_s* byval(%struct.float4_s) align 4 %x
-// CHECK: %struct.float4_s* byval(%struct.float4_s) align 4 %y
-// CHECK: %struct.float4_s* byval(%struct.float4_s) align 4 %z
+// CHECK: %struct.float4_s addrspace(101)* byref(%struct.float4_s) align 4 %0
+// CHECK: %struct.float4_s addrspace(101)* byref(%struct.float4_s) align 4 %1
+// CHECK: %struct.float4_s addrspace(101)* byref(%struct.float4_s) align 4 %2
 }
 
 typedef struct nested_s {
@@ -39,5 +39,5 @@
 
 void baz(nested_t x) {
 // CHECK-LABEL: @baz
-// CHECK: %struct.nested_s* byval(%struct.nested_s) align 8 %x)
+// CHECK: %struct.nested_s addrspace(101)* byref(%struct.nested_s) align 8 %0)
 }
Index: clang/lib/CodeGen/TargetInfo.cpp
===================================================================
--- clang/lib/CodeGen/TargetInfo.cpp
+++ clang/lib/CodeGen/TargetInfo.cpp
@@ -7081,7 +7081,11 @@
         return ABIArgInfo::getDirect(
             CGInfo.getCUDADeviceBuiltinTextureDeviceType());
     }
-    return getNaturalAlignIndirect(Ty, /* byval */ true);
+    return ABIArgInfo::getIndirectAliased(
+        getContext().getTypeAlignInChars(Ty),
+        getContext().getTargetAddressSpace(
+            getLangASFromTargetAS(/*ADDRESS_SPACE_PARAM*/ 101)),
+        false /*Realign*/, nullptr /*Padding*/);
   }
 
   if (const auto *EIT = Ty->getAs<ExtIntType>()) {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to