================
@@ -33,8 +46,114 @@ static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
   return false;
 }
 
-bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
+/// Applies default function/array/lvalue conversions to argument \p Argument
+/// of \p Call, stores the converted expression back into the call, and
+/// constant-evaluates it. Returns the zero-extended value when the argument is
+/// an integer constant no wider than 32 bits.
+///
+/// Returns std::nullopt if the conversion is invalid, or if the argument is not
+/// such a constant. In the second case it first emits err_spirv_enum_not_int
+/// together with any notes produced by the evaluation.
+///
+/// Failure must be signalled with std::nullopt. A `return true;` would
+/// implicitly build an *engaged* optional holding 1, which callers would then
+/// accept as a real value.
+static std::optional<int>
+processConstant32BitIntArgument(Sema &SemaRef, CallExpr *Call, int Argument) {
+  ExprResult Arg =
+      SemaRef.DefaultFunctionArrayLvalueConversion(Call->getArg(Argument));
+  if (Arg.isInvalid())
+    return std::nullopt;
+  Call->setArg(Argument, Arg.get());
+
+  const Expr *IntArg = Arg.get();
+  SmallVector<PartialDiagnosticAt, 8> Notes;
+  Expr::EvalResult Eval;
+  Eval.Diag = &Notes;
+  if ((!IntArg->EvaluateAsConstantExpr(Eval, SemaRef.getASTContext())) ||
+      !Eval.Val.isInt() || Eval.Val.getInt().getBitWidth() > 32) {
+    SemaRef.Diag(IntArg->getBeginLoc(), diag::err_spirv_enum_not_int)
+        << 0 << IntArg->getSourceRange();
+    for (const PartialDiagnosticAt &PDiag : Notes)
+      SemaRef.Diag(PDiag.first, PDiag.second);
+    return std::nullopt;
+  }
+  return {Eval.Val.getInt().getZExtValue()};
+}
+
+/// Semantic checks for a generic-to-explicit pointer cast builtin call.
+///
+/// Expects exactly two arguments:
+///  - argument 0: a pointer whose pointee is in the generic address space
+///    (LangAS::opencl_generic under OpenCL, LangAS::Default otherwise);
+///  - argument 1: a 32-bit integer constant naming a spirv::StorageClass that
+///    is CrossWorkgroup, Workgroup or Function.
+///
+/// On success the call's type is set to a pointer to the same pointee. That
+/// pointee keeps its other qualifiers but is moved to the global, local or
+/// private address space (SYCL or OpenCL flavour) that matches the storage
+/// class. Returns true on error and false on success.
+static bool checkGenericCastToPtr(Sema &SemaRef, CallExpr *Call) {
+  if (SemaRef.checkArgCount(Call, 2))
+    return true;
+
+  {
+    ExprResult Arg =
+        SemaRef.DefaultFunctionArrayLvalueConversion(Call->getArg(0));
+    if (Arg.isInvalid())
+      return true;
+    Call->setArg(0, Arg.get());
+
+    QualType Ty = Arg.get()->getType();
+    const auto *PtrTy = Ty->getAs<PointerType>();
+    auto AddressSpaceNotInGeneric = [&](LangAS AS) {
+      if (SemaRef.LangOpts.OpenCL)
+        return AS != LangAS::opencl_generic;
+      return AS != LangAS::Default;
+    };
+    if (!PtrTy ||
+        AddressSpaceNotInGeneric(PtrTy->getPointeeType().getAddressSpace())) {
+      SemaRef.Diag(Arg.get()->getBeginLoc(),
+                   diag::err_spirv_builtin_generic_cast_invalid_arg)
+          << Call->getSourceRange();
+      return true;
+    }
+  }
+
+  // Argument 1 must be a usable 32-bit integer constant. Any diagnostics were
+  // emitted inside the helper.
+  std::optional<int> SCInt = processConstant32BitIntArgument(SemaRef, Call, 1);
+  if (!SCInt)
+    return true;
+
+  const spirv::StorageClass StorageClass =
+      static_cast<spirv::StorageClass>(*SCInt);
+  if (StorageClass != spirv::StorageClass::CrossWorkgroup &&
+      StorageClass != spirv::StorageClass::Workgroup &&
+      StorageClass != spirv::StorageClass::Function) {
+    SemaRef.Diag(Call->getArg(1)->getBeginLoc(),
+                 diag::err_spirv_enum_not_valid)
+        << 0 << Call->getArg(1)->getSourceRange();
+    return true;
+  }
+
+  auto RT = Call->getArg(0)->getType();
+  RT = RT->getPointeeType();
+  auto Qual = RT.getQualifiers();
+  LangAS AddrSpace;
+  switch (StorageClass) {
+  case spirv::StorageClass::CrossWorkgroup:
+    AddrSpace =
+        SemaRef.LangOpts.isSYCL() ? LangAS::sycl_global : LangAS::opencl_global;
+    break;
+  case spirv::StorageClass::Workgroup:
+    AddrSpace =
+        SemaRef.LangOpts.isSYCL() ? LangAS::sycl_local : LangAS::opencl_local;
+    break;
+  case spirv::StorageClass::Function:
+    AddrSpace = SemaRef.LangOpts.isSYCL() ? LangAS::sycl_private
+                                          : LangAS::opencl_private;
+    break;
+  default:
+    llvm_unreachable("Invalid builtin function");
+  }
+  Qual.setAddressSpace(AddrSpace);
+  Call->setType(SemaRef.getASTContext().getPointerType(
+      SemaRef.getASTContext().getQualifiedType(RT.getUnqualifiedType(), Qual)));
+
+  return false;
+}
+
+bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
+                                              unsigned BuiltinID,
                                               CallExpr *TheCall) {
+  if (BuiltinID >= SPIRV::FirstVKBuiltin && BuiltinID <= SPIRV::LastVKBuiltin)
+    if (TI.getTriple().getArch() != llvm::Triple::spirv) {
----------------
Naghasan wrote:
done https://github.com/llvm/llvm-project/pull/137805

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits