jlebar updated this revision to Diff 67378. jlebar added a comment. Address review comments.
Also swap the two tests -- I had the wrong names on them. https://reviews.llvm.org/D23242 Files: clang/include/clang/Sema/Sema.h clang/lib/Sema/SemaCUDA.cpp clang/lib/Sema/SemaDeclCXX.cpp clang/lib/Sema/SemaExpr.cpp clang/lib/Sema/SemaOverload.cpp clang/test/CodeGenCUDA/host-device-calls-host.cu clang/test/SemaCUDA/Inputs/cuda.h clang/test/SemaCUDA/call-device-fn-from-host.cu clang/test/SemaCUDA/call-host-fn-from-device.cu
Index: clang/test/SemaCUDA/call-host-fn-from-device.cu =================================================================== --- /dev/null +++ clang/test/SemaCUDA/call-host-fn-from-device.cu @@ -0,0 +1,84 @@ +// RUN: %clang_cc1 %s --std=c++11 -triple nvptx-unknown-unknown -fcuda-is-device -emit-llvm -o - -verify + +// Note: This test won't work with -fsyntax-only, because some of these errors +// are emitted during codegen. + +#include "Inputs/cuda.h" + +extern "C" void host_fn() {} + +struct S { + S() {} + ~S() { host_fn(); } + int x; +}; + +struct T { + __host__ __device__ void hd() { host_fn(); } + // expected-error@-1 {{reference to __host__ function 'host_fn' in __host__ __device__ function}} + + // No error; this is (implicitly) inline and is never called, so isn't + // codegen'ed. + __host__ __device__ void hd2() { host_fn(); } + + __host__ __device__ void hd3(); + + void h() {} +}; + +__host__ __device__ void T::hd3() { + host_fn(); + // expected-error@-1 {{reference to __host__ function 'host_fn' in __host__ __device__ function}} +} + +template <typename T> __host__ __device__ void hd2() { host_fn(); } +// expected-error@-1 {{reference to __host__ function 'host_fn' in __host__ __device__ function}} +__global__ void kernel() { hd2<int>(); } + +__host__ __device__ void hd() { host_fn(); } +// expected-error@-1 {{reference to __host__ function 'host_fn' in __host__ __device__ function}} + +template <typename T> __host__ __device__ void hd3() { host_fn(); } +// expected-error@-1 {{reference to __host__ function 'host_fn' in __host__ __device__ function}} +__device__ void device_fn() { hd3<int>(); } + +// No error because this is never instantiated. 
+template <typename T> __host__ __device__ void hd4() { host_fn(); } + +__host__ __device__ void local_var() { + S s; + // expected-error@-1 {{reference to __host__ function 'S' in __host__ __device__ function}} +} + +__host__ __device__ void placement_new(char *ptr) { + ::new(ptr) S(); + // expected-error@-1 {{reference to __host__ function 'S' in __host__ __device__ function}} +} + +__host__ __device__ void explicit_destructor(S *s) { + s->~S(); + // expected-error@-1 {{reference to __host__ function '~S' in __host__ __device__ function}} +} + +__host__ __device__ void hd_member_fn() { + T t; + // Necessary to trigger an error on T::hd. It's (implicitly) inline, so + // isn't codegen'ed until we call it. + t.hd(); +} + +__host__ __device__ void h_member_fn() { + T t; + t.h(); + // expected-error@-1 {{reference to __host__ function 'h' in __host__ __device__ function}} +} + +__host__ __device__ void fn_ptr() { + auto* ptr = &host_fn; + // expected-error@-1 {{reference to __host__ function 'host_fn' in __host__ __device__ function}} +} + +template <typename T> +__host__ __device__ void fn_ptr_template() { + auto* ptr = &host_fn; // Not an error because the template isn't instantiated. +} Index: clang/test/SemaCUDA/call-device-fn-from-host.cu =================================================================== --- /dev/null +++ clang/test/SemaCUDA/call-device-fn-from-host.cu @@ -0,0 +1,80 @@ +// RUN: %clang_cc1 %s --std=c++11 -triple nvptx-unknown-unknown -emit-llvm -o - -verify + +// Note: This test won't work with -fsyntax-only, because some of these errors +// are emitted during codegen. 
+ +#include "Inputs/cuda.h" + +__device__ void device_fn() {} + +struct S { + __device__ S() {} + __device__ ~S() { device_fn(); } + int x; +}; + +struct T { + __host__ __device__ void hd() { device_fn(); } + // expected-error@-1 {{reference to __device__ function 'device_fn' in __host__ __device__ function}} + + // No error; this is (implicitly) inline and is never called, so isn't + // codegen'ed. + __host__ __device__ void hd2() { device_fn(); } + + __host__ __device__ void hd3(); + + __device__ void d() {} +}; + +__host__ __device__ void T::hd3() { + device_fn(); + // expected-error@-1 {{reference to __device__ function 'device_fn' in __host__ __device__ function}} +} + +template <typename T> __host__ __device__ void hd2() { device_fn(); } +// expected-error@-1 {{reference to __device__ function 'device_fn' in __host__ __device__ function}} +void host_fn() { hd2<int>(); } + +__host__ __device__ void hd() { device_fn(); } +// expected-error@-1 {{reference to __device__ function 'device_fn' in __host__ __device__ function}} + +// No error because this is never instantiated. +template <typename T> __host__ __device__ void hd3() { device_fn(); } + +__host__ __device__ void local_var() { + S s; + // expected-error@-1 {{reference to __device__ function 'S' in __host__ __device__ function}} +} + +__host__ __device__ void placement_new(char *ptr) { + ::new(ptr) S(); + // expected-error@-1 {{reference to __device__ function 'S' in __host__ __device__ function}} +} + +__host__ __device__ void explicit_destructor(S *s) { + s->~S(); + // expected-error@-1 {{reference to __device__ function '~S' in __host__ __device__ function}} +} + +__host__ __device__ void hd_member_fn() { + T t; + // Necessary to trigger an error on T::hd. It's (implicitly) inline, so + // isn't codegen'ed until we call it. 
+ t.hd(); +} + +__host__ __device__ void h_member_fn() { + T t; + t.d(); + // expected-error@-1 {{reference to __device__ function 'd' in __host__ __device__ function}} +} + +__host__ __device__ void fn_ptr() { + auto* ptr = &device_fn; + // expected-error@-1 {{reference to __device__ function 'device_fn' in __host__ __device__ function}} +} + +template <typename T> +__host__ __device__ void fn_ptr_template() { + auto* ptr = &device_fn; // Not an error because the template isn't instantiated. +} Index: clang/test/SemaCUDA/Inputs/cuda.h =================================================================== --- clang/test/SemaCUDA/Inputs/cuda.h +++ clang/test/SemaCUDA/Inputs/cuda.h @@ -21,4 +21,9 @@ int cudaConfigureCall(dim3 gridSize, dim3 blockSize, size_t sharedSize = 0, cudaStream_t stream = 0); + +// Device-side placement new overloads. +__device__ void *operator new(__SIZE_TYPE__, void *p) { return p; } +__device__ void *operator new[](__SIZE_TYPE__, void *p) { return p; } + #endif // !__NVCC__ Index: clang/test/CodeGenCUDA/host-device-calls-host.cu =================================================================== --- clang/test/CodeGenCUDA/host-device-calls-host.cu +++ /dev/null @@ -1,32 +0,0 @@ -// RUN: %clang_cc1 %s -triple nvptx-unknown-unknown -fcuda-is-device -Wno-cuda-compat -emit-llvm -o - | FileCheck %s - -#include "Inputs/cuda.h" - -extern "C" -void host_function() {} - -// CHECK-LABEL: define void @hd_function_a -extern "C" -__host__ __device__ void hd_function_a() { - // CHECK: call void @host_function - host_function(); -} - -// CHECK: declare void @host_function - -// CHECK-LABEL: define void @hd_function_b -extern "C" -__host__ __device__ void hd_function_b(bool b) { if (b) host_function(); } - -// CHECK-LABEL: define void @device_function_b -extern "C" -__device__ void device_function_b() { hd_function_b(false); } - -// CHECK-LABEL: define void @global_function -extern "C" -__global__ void global_function() { - // CHECK: call void 
@device_function_b - device_function_b(); -} - -// CHECK: !{{[0-9]+}} = !{void ()* @global_function, !"kernel", i32 1} Index: clang/lib/Sema/SemaOverload.cpp =================================================================== --- clang/lib/Sema/SemaOverload.cpp +++ clang/lib/Sema/SemaOverload.cpp @@ -12327,19 +12327,6 @@ new (Context) CXXMemberCallExpr(Context, MemExprE, Args, ResultType, VK, RParenLoc); - // (CUDA B.1): Check for invalid calls between targets. - if (getLangOpts().CUDA) { - if (const FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext)) { - if (!IsAllowedCUDACall(Caller, Method)) { - Diag(MemExpr->getMemberLoc(), diag::err_ref_bad_target) - << IdentifyCUDATarget(Method) << Method->getIdentifier() - << IdentifyCUDATarget(Caller); - Diag(Method->getLocation(), diag::note_previous_decl) << Method; - return ExprError(); - } - } - } - // Check for a valid return type. if (CheckCallReturnType(Method->getReturnType(), MemExpr->getMemberLoc(), TheCall, Method)) Index: clang/lib/Sema/SemaExpr.cpp =================================================================== --- clang/lib/Sema/SemaExpr.cpp +++ clang/lib/Sema/SemaExpr.cpp @@ -1747,17 +1747,9 @@ const CXXScopeSpec *SS, NamedDecl *FoundD, const TemplateArgumentListInfo *TemplateArgs) { if (getLangOpts().CUDA) - if (const FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext)) - if (const FunctionDecl *Callee = dyn_cast<FunctionDecl>(D)) { - if (!IsAllowedCUDACall(Caller, Callee)) { - Diag(NameInfo.getLoc(), diag::err_ref_bad_target) - << IdentifyCUDATarget(Callee) << D->getIdentifier() - << IdentifyCUDATarget(Caller); - Diag(D->getLocation(), diag::note_previous_decl) - << D->getIdentifier(); - return ExprError(); - } - } + if (FunctionDecl *Callee = dyn_cast<FunctionDecl>(D)) + if (!CheckCUDACall(NameInfo.getLoc(), Callee)) + return ExprError(); bool RefersToCapturedVariable = isa<VarDecl>(D) && @@ -5124,37 +5116,35 @@ return Callee->getMinRequiredArguments() <= NumArgs; } -/// ActOnCallExpr - 
Handle a call to Fn with the specified array of arguments. -/// This provides the location of the left/right parens and a list of comma -/// locations. -ExprResult -Sema::ActOnCallExpr(Scope *S, Expr *Fn, SourceLocation LParenLoc, - MultiExprArg ArgExprs, SourceLocation RParenLoc, - Expr *ExecConfig, bool IsExecConfig) { +static ExprResult ActOnCallExprImpl(Sema &S, Scope *Scope, Expr *Fn, + SourceLocation LParenLoc, + MultiExprArg ArgExprs, + SourceLocation RParenLoc, Expr *ExecConfig, + bool IsExecConfig) { // Since this might be a postfix expression, get rid of ParenListExprs. - ExprResult Result = MaybeConvertParenListExprToParenExpr(S, Fn); + ExprResult Result = S.MaybeConvertParenListExprToParenExpr(Scope, Fn); if (Result.isInvalid()) return ExprError(); Fn = Result.get(); - if (checkArgsForPlaceholders(*this, ArgExprs)) + if (checkArgsForPlaceholders(S, ArgExprs)) return ExprError(); - if (getLangOpts().CPlusPlus) { + if (S.getLangOpts().CPlusPlus) { // If this is a pseudo-destructor expression, build the call immediately. if (isa<CXXPseudoDestructorExpr>(Fn)) { if (!ArgExprs.empty()) { // Pseudo-destructor calls should not have any arguments. 
- Diag(Fn->getLocStart(), diag::err_pseudo_dtor_call_with_args) - << FixItHint::CreateRemoval( - SourceRange(ArgExprs.front()->getLocStart(), - ArgExprs.back()->getLocEnd())); + S.Diag(Fn->getLocStart(), diag::err_pseudo_dtor_call_with_args) + << FixItHint::CreateRemoval( + SourceRange(ArgExprs.front()->getLocStart(), + ArgExprs.back()->getLocEnd())); } - return new (Context) - CallExpr(Context, Fn, None, Context.VoidTy, VK_RValue, RParenLoc); + return new (S.Context) + CallExpr(S.Context, Fn, None, S.Context.VoidTy, VK_RValue, RParenLoc); } - if (Fn->getType() == Context.PseudoObjectTy) { - ExprResult result = CheckPlaceholderExpr(Fn); + if (Fn->getType() == S.Context.PseudoObjectTy) { + ExprResult result = S.CheckPlaceholderExpr(Fn); if (result.isInvalid()) return ExprError(); Fn = result.get(); } @@ -5169,50 +5159,53 @@ if (Dependent) { if (ExecConfig) { - return new (Context) CUDAKernelCallExpr( - Context, Fn, cast<CallExpr>(ExecConfig), ArgExprs, - Context.DependentTy, VK_RValue, RParenLoc); + return new (S.Context) CUDAKernelCallExpr( + S.Context, Fn, cast<CallExpr>(ExecConfig), ArgExprs, + S.Context.DependentTy, VK_RValue, RParenLoc); } else { - return new (Context) CallExpr( - Context, Fn, ArgExprs, Context.DependentTy, VK_RValue, RParenLoc); + return new (S.Context) + CallExpr(S.Context, Fn, ArgExprs, S.Context.DependentTy, VK_RValue, + RParenLoc); } } // Determine whether this is a call to an object (C++ [over.call.object]). 
if (Fn->getType()->isRecordType()) - return BuildCallToObjectOfClassType(S, Fn, LParenLoc, ArgExprs, - RParenLoc); + return S.BuildCallToObjectOfClassType(Scope, Fn, LParenLoc, ArgExprs, + RParenLoc); - if (Fn->getType() == Context.UnknownAnyTy) { - ExprResult result = rebuildUnknownAnyFunction(*this, Fn); + if (Fn->getType() == S.Context.UnknownAnyTy) { + ExprResult result = rebuildUnknownAnyFunction(S, Fn); if (result.isInvalid()) return ExprError(); Fn = result.get(); } - if (Fn->getType() == Context.BoundMemberTy) { - return BuildCallToMemberFunction(S, Fn, LParenLoc, ArgExprs, RParenLoc); + if (Fn->getType() == S.Context.BoundMemberTy) { + return S.BuildCallToMemberFunction(Scope, Fn, LParenLoc, ArgExprs, + RParenLoc); } } // Check for overloaded calls. This can happen even in C due to extensions. - if (Fn->getType() == Context.OverloadTy) { + if (Fn->getType() == S.Context.OverloadTy) { OverloadExpr::FindResult find = OverloadExpr::find(Fn); - // We aren't supposed to apply this logic for if there's an '&' involved. + // We aren't supposed to apply this logic if there's an '&' + // involved. if (!find.HasFormOfMemberPointer) { OverloadExpr *ovl = find.Expression; if (UnresolvedLookupExpr *ULE = dyn_cast<UnresolvedLookupExpr>(ovl)) - return BuildOverloadedCallExpr(S, Fn, ULE, LParenLoc, ArgExprs, - RParenLoc, ExecConfig, - /*AllowTypoCorrection=*/true, - find.IsAddressOfOperand); - return BuildCallToMemberFunction(S, Fn, LParenLoc, ArgExprs, RParenLoc); + return S.BuildOverloadedCallExpr( + Scope, Fn, ULE, LParenLoc, ArgExprs, RParenLoc, ExecConfig, + /*AllowTypoCorrection=*/true, find.IsAddressOfOperand); + return S.BuildCallToMemberFunction(Scope, Fn, LParenLoc, ArgExprs, + RParenLoc); } } // If we're directly calling a function, get the appropriate declaration.
- if (Fn->getType() == Context.UnknownAnyTy) { - ExprResult result = rebuildUnknownAnyFunction(*this, Fn); + if (Fn->getType() == S.Context.UnknownAnyTy) { + ExprResult result = rebuildUnknownAnyFunction(S, Fn); if (result.isInvalid()) return ExprError(); Fn = result.get(); } @@ -5236,21 +5229,21 @@ // Rewrite the function decl for this builtin by replacing parameters // with no explicit address space with the address space of the arguments // in ArgExprs. - if ((FDecl = rewriteBuiltinFunctionDecl(this, Context, FDecl, ArgExprs))) { + if ((FDecl = + rewriteBuiltinFunctionDecl(&S, S.Context, FDecl, ArgExprs))) { NDecl = FDecl; - Fn = DeclRefExpr::Create(Context, FDecl->getQualifierLoc(), - SourceLocation(), FDecl, false, - SourceLocation(), FDecl->getType(), - Fn->getValueKind(), FDecl); + Fn = DeclRefExpr::Create( + S.Context, FDecl->getQualifierLoc(), SourceLocation(), FDecl, false, + SourceLocation(), FDecl->getType(), Fn->getValueKind(), FDecl); } } } else if (isa<MemberExpr>(NakedFn)) NDecl = cast<MemberExpr>(NakedFn)->getMemberDecl(); if (FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(NDecl)) { if (CallingNDeclIndirectly && - !checkAddressOfFunctionIsAvailable(FD, /*Complain=*/true, - Fn->getLocStart())) + !S.checkAddressOfFunctionIsAvailable(FD, /*Complain=*/true, + Fn->getLocStart())) return ExprError(); // CheckEnableIf assumes that the we're passing in a sane number of args for @@ -5260,22 +5253,42 @@ // number of args looks incorrect, don't do enable_if checks; we should've // already emitted an error about the bad call. if (FD->hasAttr<EnableIfAttr>() && - isNumberOfArgsValidForCall(*this, FD, ArgExprs.size())) { - if (const EnableIfAttr *Attr = CheckEnableIf(FD, ArgExprs, true)) { - Diag(Fn->getLocStart(), - isa<CXXMethodDecl>(FD) ? 
- diag::err_ovl_no_viable_member_function_in_call : - diag::err_ovl_no_viable_function_in_call) - << FD << FD->getSourceRange(); - Diag(FD->getLocation(), - diag::note_ovl_candidate_disabled_by_enable_if_attr) + isNumberOfArgsValidForCall(S, FD, ArgExprs.size())) { + if (const EnableIfAttr *Attr = S.CheckEnableIf(FD, ArgExprs, true)) { + S.Diag(Fn->getLocStart(), + isa<CXXMethodDecl>(FD) + ? diag::err_ovl_no_viable_member_function_in_call + : diag::err_ovl_no_viable_function_in_call) + << FD << FD->getSourceRange(); + S.Diag(FD->getLocation(), + diag::note_ovl_candidate_disabled_by_enable_if_attr) << Attr->getCond()->getSourceRange() << Attr->getMessage(); } } } - return BuildResolvedCallExpr(Fn, NDecl, LParenLoc, ArgExprs, RParenLoc, - ExecConfig, IsExecConfig); + return S.BuildResolvedCallExpr(Fn, NDecl, LParenLoc, ArgExprs, RParenLoc, + ExecConfig, IsExecConfig); +} + +/// ActOnCallExpr - Handle a call to Fn with the specified array of arguments. +/// This provides the location of the left/right parens and a list of comma +/// locations. +ExprResult Sema::ActOnCallExpr(Scope *S, Expr *Fn, SourceLocation LParenLoc, + MultiExprArg ArgExprs, SourceLocation RParenLoc, + Expr *ExecConfig, bool IsExecConfig) { + ExprResult Ret = ActOnCallExprImpl(*this, S, Fn, LParenLoc, ArgExprs, + RParenLoc, ExecConfig, IsExecConfig); + + // If appropriate, check that this is a valid CUDA call (and emit an error if + // the call is not allowed). + if (getLangOpts().CUDA && Ret.isUsable()) + if (auto *Call = dyn_cast<CallExpr>(Ret.get())) + if (auto *FD = Call->getDirectCallee()) + if (!CheckCUDACall(Call->getLocStart(), FD)) + return ExprError(); + + return Ret; } /// ActOnAsTypeExpr - create a new asType (bitcast) from the arguments. 
Index: clang/lib/Sema/SemaDeclCXX.cpp =================================================================== --- clang/lib/Sema/SemaDeclCXX.cpp +++ clang/lib/Sema/SemaDeclCXX.cpp @@ -11507,6 +11507,8 @@ DeclInitType->getBaseElementTypeUnsafe()->getAsCXXRecordDecl()) && "given constructor for wrong type"); MarkFunctionReferenced(ConstructLoc, Constructor); + if (getLangOpts().CUDA && !CheckCUDACall(ConstructLoc, Constructor)) + return ExprError(); return CXXConstructExpr::Create( Context, DeclInitType, ConstructLoc, Constructor, Elidable, Index: clang/lib/Sema/SemaCUDA.cpp =================================================================== --- clang/lib/Sema/SemaCUDA.cpp +++ clang/lib/Sema/SemaCUDA.cpp @@ -480,3 +480,40 @@ NewD->addAttr(CUDAHostAttr::CreateImplicit(Context)); NewD->addAttr(CUDADeviceAttr::CreateImplicit(Context)); } + +bool Sema::CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee) { + assert(getLangOpts().CUDA && + "Should only be called during CUDA compilation."); + assert(Callee && "Callee may not be null."); + FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext); + if (!Caller) + return true; + + Sema::CUDAFunctionPreference Pref = IdentifyCUDAPreference(Caller, Callee); + if (Pref == Sema::CFP_Never) { + Diag(Loc, diag::err_ref_bad_target) << IdentifyCUDATarget(Callee) << Callee + << IdentifyCUDATarget(Caller); + Diag(Callee->getLocation(), diag::note_previous_decl) << Callee; + return false; + } + if (Pref == Sema::CFP_WrongSide) { + // We have to do this odd dance to create our PartialDiagnostic because we + // want its storage to be allocated with operator new, not in an arena. 
+ PartialDiagnostic PD{PartialDiagnostic::NullDiagnostic()}; + PD.Reset(diag::err_ref_bad_target); + PD << IdentifyCUDATarget(Callee) << Callee << IdentifyCUDATarget(Caller); + Caller->addDeferredDiag({Loc, std::move(PD)}); + + // Defer the "declared here" note as well. Emitting it immediately would + // leave an orphaned note whenever Caller is never codegen'ed. + PartialDiagnostic NotePD{PartialDiagnostic::NullDiagnostic()}; + NotePD.Reset(diag::note_previous_decl); + NotePD << Callee; + Caller->addDeferredDiag({Callee->getLocation(), std::move(NotePD)}); + + // This is not immediately an error, so return true.
The deferred errors + // will be emitted if and when Caller is codegen'ed. + return true; + } + return true; +} Index: clang/include/clang/Sema/Sema.h =================================================================== --- clang/include/clang/Sema/Sema.h +++ clang/include/clang/Sema/Sema.h @@ -9159,6 +9159,18 @@ void maybeAddCUDAHostDeviceAttrs(Scope *S, FunctionDecl *FD, const LookupResult &Previous); + /// Check whether we're allowed to call Callee from the current context. + /// + /// If the call is never allowed in a semantically-correct program + /// (CFP_Never), emits an error and returns false. + /// + /// If the call is allowed in semantically-correct programs, but only if it's + /// never codegen'ed (CFP_WrongSide), defers the error and its note so that + /// both are emitted only if and when the caller is codegen'ed; returns true. + /// + /// Otherwise, returns true without emitting any diagnostics. + bool CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee); + /// Finds a function in \p Matches with highest calling priority /// from \p Caller context and erases all functions with lower /// calling priority.
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits