https://github.com/cassiebeckley updated https://github.com/llvm/llvm-project/pull/134034
>From 78ac1bc4225b41bc4b9fbd9fd9ab9dc82a2953ca Mon Sep 17 00:00:00 2001 From: Cassandra Beckley <cbeck...@google.com> Date: Tue, 1 Apr 2025 23:12:02 -0700 Subject: [PATCH 1/2] [HLSL] Implement `SpirvType` and `SpirvOpaqueType` This implements the design proposed by [Representing SpirvType in Clang's Type System](https://github.com/llvm/wg-hlsl/pull/181). It creates `HLSLInlineSpirvType` as a new `Type` subclass, and `__hlsl_spirv_type` as a new builtin type template to create such a type. This new type is lowered to the `spirv.Type` target extension type, as described in [Target Extension Types for Inline SPIR-V and Decorated Types](https://github.com/llvm/wg-hlsl/blob/main/proposals/0017-inline-spirv-and-decorated-types.md). --- clang/include/clang-c/Index.h | 3 +- clang/include/clang/AST/ASTContext.h | 5 + clang/include/clang/AST/ASTNodeTraverser.h | 18 +++ clang/include/clang/AST/PropertiesBase.td | 1 + clang/include/clang/AST/RecursiveASTVisitor.h | 11 ++ clang/include/clang/AST/Type.h | 142 +++++++++++++++++- clang/include/clang/AST/TypeLoc.h | 19 +++ clang/include/clang/AST/TypeProperties.td | 18 +++ clang/include/clang/Basic/BuiltinTemplates.td | 18 ++- .../clang/Basic/DiagnosticSemaKinds.td | 3 + clang/include/clang/Basic/TypeNodes.td | 1 + .../clang/Serialization/ASTRecordReader.h | 2 + .../clang/Serialization/ASTRecordWriter.h | 14 ++ .../clang/Serialization/TypeBitCodes.def | 1 + clang/lib/AST/ASTContext.cpp | 59 ++++++++ clang/lib/AST/ASTImporter.cpp | 42 ++++++ clang/lib/AST/ASTStructuralEquivalence.cpp | 17 +++ clang/lib/AST/ExprConstant.cpp | 1 + clang/lib/AST/ItaniumMangle.cpp | 40 ++++- clang/lib/AST/MicrosoftMangle.cpp | 5 + clang/lib/AST/Type.cpp | 14 ++ clang/lib/AST/TypePrinter.cpp | 48 ++++++ clang/lib/CodeGen/CGDebugInfo.cpp | 8 + clang/lib/CodeGen/CGDebugInfo.h | 1 + clang/lib/CodeGen/CodeGenFunction.cpp | 2 + clang/lib/CodeGen/CodeGenTypes.cpp | 6 + clang/lib/CodeGen/ItaniumCXXABI.cpp | 2 + clang/lib/CodeGen/Targets/SPIR.cpp | 90 
++++++++++- clang/lib/Headers/CMakeLists.txt | 1 + clang/lib/Headers/hlsl.h | 4 + clang/lib/Headers/hlsl/hlsl_spirv.h | 30 ++++ clang/lib/Sema/SemaExpr.cpp | 1 + clang/lib/Sema/SemaLookup.cpp | 21 ++- clang/lib/Sema/SemaTemplate.cpp | 103 ++++++++++++- clang/lib/Sema/SemaTemplateDeduction.cpp | 2 + clang/lib/Sema/SemaType.cpp | 1 + clang/lib/Sema/TreeTransform.h | 7 + clang/lib/Serialization/ASTReader.cpp | 9 ++ clang/lib/Serialization/ASTWriter.cpp | 4 + .../test/AST/HLSL/Inputs/pch_spirv_type.hlsl | 6 + clang/test/AST/HLSL/ast-dump-SpirvType.hlsl | 27 ++++ clang/test/AST/HLSL/pch_spirv_type.hlsl | 17 +++ clang/test/AST/HLSL/vector-alias.hlsl | 105 +++++++------ .../inline/SpirvType.alignment.hlsl | 16 ++ .../inline/SpirvType.dx.error.hlsl | 12 ++ clang/test/CodeGenHLSL/inline/SpirvType.hlsl | 68 +++++++++ .../inline/SpirvType.incomplete.hlsl | 14 ++ .../inline/SpirvType.literal.error.hlsl | 11 ++ clang/tools/libclang/CIndex.cpp | 5 + clang/tools/libclang/CXType.cpp | 1 + .../TableGen/ClangBuiltinTemplatesEmitter.cpp | 72 +++++++-- 51 files changed, 1052 insertions(+), 76 deletions(-) create mode 100644 clang/lib/Headers/hlsl/hlsl_spirv.h create mode 100644 clang/test/AST/HLSL/Inputs/pch_spirv_type.hlsl create mode 100644 clang/test/AST/HLSL/ast-dump-SpirvType.hlsl create mode 100644 clang/test/AST/HLSL/pch_spirv_type.hlsl create mode 100644 clang/test/CodeGenHLSL/inline/SpirvType.alignment.hlsl create mode 100644 clang/test/CodeGenHLSL/inline/SpirvType.dx.error.hlsl create mode 100644 clang/test/CodeGenHLSL/inline/SpirvType.hlsl create mode 100644 clang/test/CodeGenHLSL/inline/SpirvType.incomplete.hlsl create mode 100644 clang/test/CodeGenHLSL/inline/SpirvType.literal.error.hlsl diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h index 38e2417dcd181..757f8a3afc758 100644 --- a/clang/include/clang-c/Index.h +++ b/clang/include/clang-c/Index.h @@ -3034,7 +3034,8 @@ enum CXTypeKind { /* HLSL Types */ CXType_HLSLResource = 179, - 
CXType_HLSLAttributedResource = 180 + CXType_HLSLAttributedResource = 180, + CXType_HLSLInlineSpirv = 181 }; /** diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h index a24f30815e6b9..c62f9f7672010 100644 --- a/clang/include/clang/AST/ASTContext.h +++ b/clang/include/clang/AST/ASTContext.h @@ -260,6 +260,7 @@ class ASTContext : public RefCountedBase<ASTContext> { DependentBitIntTypes; mutable llvm::FoldingSet<BTFTagAttributedType> BTFTagAttributedTypes; llvm::FoldingSet<HLSLAttributedResourceType> HLSLAttributedResourceTypes; + llvm::FoldingSet<HLSLInlineSpirvType> HLSLInlineSpirvTypes; mutable llvm::FoldingSet<CountAttributedType> CountAttributedTypes; @@ -1795,6 +1796,10 @@ class ASTContext : public RefCountedBase<ASTContext> { QualType Wrapped, QualType Contained, const HLSLAttributedResourceType::Attributes &Attrs); + QualType getHLSLInlineSpirvType(uint32_t Opcode, uint32_t Size, + uint32_t Alignment, + ArrayRef<SpirvOperand> Operands); + QualType getSubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl, unsigned Index, diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h index f086d8134a64b..fd9108221590e 100644 --- a/clang/include/clang/AST/ASTNodeTraverser.h +++ b/clang/include/clang/AST/ASTNodeTraverser.h @@ -450,6 +450,24 @@ class ASTNodeTraverser if (!Contained.isNull()) Visit(Contained); } + void VisitHLSLInlineSpirvType(const HLSLInlineSpirvType *T) { + for (auto &Operand : T->getOperands()) { + using SpirvOperandKind = SpirvOperand::SpirvOperandKind; + + switch (Operand.getKind()) { + case SpirvOperandKind::kConstantId: + case SpirvOperandKind::kLiteral: + break; + + case SpirvOperandKind::kTypeId: + Visit(Operand.getResultType()); + break; + + default: + llvm_unreachable("Invalid SpirvOperand kind!"); + } + } + } void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *) {} void VisitSubstTemplateTypeParmPackType(const 
SubstTemplateTypeParmPackType *T) { diff --git a/clang/include/clang/AST/PropertiesBase.td b/clang/include/clang/AST/PropertiesBase.td index 5171555008ac9..7d5e6671fec7d 100644 --- a/clang/include/clang/AST/PropertiesBase.td +++ b/clang/include/clang/AST/PropertiesBase.td @@ -147,6 +147,7 @@ def UInt64 : CountPropertyType<"uint64_t">; def UnaryTypeTransformKind : EnumPropertyType<"UnaryTransformType::UTTKind">; def VectorKind : EnumPropertyType<"VectorKind">; def TypeCoupledDeclRefInfo : PropertyType; +def HLSLSpirvOperand : PropertyType<"SpirvOperand"> { let PassByReference = 1; } def ExceptionSpecInfo : PropertyType<"FunctionProtoType::ExceptionSpecInfo"> { let BufferElementTypes = [ QualType ]; diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index 0530996ed20d3..255e39a46db09 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -1154,6 +1154,14 @@ DEF_TRAVERSE_TYPE(BTFTagAttributedType, DEF_TRAVERSE_TYPE(HLSLAttributedResourceType, { TRY_TO(TraverseType(T->getWrappedType())); }) +DEF_TRAVERSE_TYPE(HLSLInlineSpirvType, { + for (auto &Operand : T->getOperands()) { + if (Operand.isConstant() || Operand.isType()) { + TRY_TO(TraverseType(Operand.getResultType())); + } + } +}) + DEF_TRAVERSE_TYPE(ParenType, { TRY_TO(TraverseType(T->getInnerType())); }) DEF_TRAVERSE_TYPE(MacroQualifiedType, @@ -1457,6 +1465,9 @@ DEF_TRAVERSE_TYPELOC(BTFTagAttributedType, DEF_TRAVERSE_TYPELOC(HLSLAttributedResourceType, { TRY_TO(TraverseTypeLoc(TL.getWrappedLoc())); }) +DEF_TRAVERSE_TYPELOC(HLSLInlineSpirvType, + { TRY_TO(TraverseType(TL.getType())); }) + DEF_TRAVERSE_TYPELOC(ElaboratedType, { if (TL.getQualifierLoc()) { TRY_TO(TraverseNestedNameSpecifierLoc(TL.getQualifierLoc())); diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h index cfd417068abb7..f351e68d5297d 100644 --- a/clang/include/clang/AST/Type.h +++ 
b/clang/include/clang/AST/Type.h @@ -2652,6 +2652,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase { bool isHLSLSpecificType() const; // Any HLSL specific type bool isHLSLBuiltinIntangibleType() const; // Any HLSL builtin intangible type bool isHLSLAttributedResourceType() const; + bool isHLSLInlineSpirvType() const; bool isHLSLResourceRecord() const; bool isHLSLIntangibleType() const; // Any HLSL intangible type (builtin, array, class) @@ -6330,6 +6331,140 @@ class HLSLAttributedResourceType : public Type, public llvm::FoldingSetNode { findHandleTypeOnResource(const Type *RT); }; +/// Instances of this class represent operands to a SPIR-V type instruction. +/// An operand is either a constant (result type + value), a literal (value +/// only), or a type (result type only). +class SpirvOperand { +public: + enum SpirvOperandKind : unsigned char { + kInvalid, ///< Uninitialized. + kConstantId, ///< Integral value to represent as a SPIR-V OpConstant + ///< instruction ID. + kLiteral, ///< Integral value to represent as an immediate literal. + kTypeId, ///< Type to represent as a SPIR-V type ID. + + kMax, + }; + +private: + SpirvOperandKind Kind = kInvalid; + + QualType ResultType; + llvm::APInt Value; // Signedness of constants is represented by ResultType.
+ +public: + /// Creates an invalid (uninitialized) operand. + SpirvOperand() = default; + + SpirvOperand(SpirvOperandKind Kind, QualType ResultType, llvm::APInt Value) + : Kind(Kind), ResultType(ResultType), Value(std::move(Value)) {} + + // Copy, move and destruction are the implicit member-wise operations + // (Rule of Zero), which also keeps the implicit move operations available. + + /// Operands are equal when kind, result type, bit width and value all match. + /// The bit widths are compared before the values because APInt::operator== + /// asserts on mismatched widths, and nothing here guarantees that two + /// operands of the same kind were built from APInts of the same width. + bool operator==(const SpirvOperand &Other) const { + return Kind == Other.Kind && ResultType == Other.ResultType && + Value.getBitWidth() == Other.Value.getBitWidth() && + Value == Other.Value; + } + + bool operator!=(const SpirvOperand &Other) const { return !(*this == Other); } + + SpirvOperandKind getKind() const { return Kind; } + + bool isValid() const { return Kind != kInvalid && Kind < kMax; } + bool isConstant() const { return Kind == kConstantId; } + bool isLiteral() const { return Kind == kLiteral; } + bool isType() const { return Kind == kTypeId; } + + /// Value of a constant or literal operand. + llvm::APInt getValue() const { + assert((isConstant() || isLiteral()) && + "This is not an operand with a value!"); + return Value; + } + + /// Result type of a constant operand, or the type of a type operand. + QualType getResultType() const { + assert((isConstant() || isType()) && + "This is not an operand with a result type!"); + return ResultType; + } + + /// Creates an operand emitted as an OpConstant of \p ResultType. + static SpirvOperand createConstant(QualType ResultType, llvm::APInt Val) { + return SpirvOperand(kConstantId, ResultType, Val); + } + + /// Creates an operand emitted as an immediate literal. 
+ static SpirvOperand createLiteral(llvm::APInt Val) { + return SpirvOperand(kLiteral, QualType(), Val); + } + + /// Creates an operand emitted as the SPIR-V type ID of \p T. + static SpirvOperand createType(QualType T) { + return SpirvOperand(kTypeId, T, llvm::APInt()); + } + + void Profile(llvm::FoldingSetNodeID &ID) const { + ID.AddInteger(Kind); + ID.AddPointer(ResultType.getAsOpaquePtr()); + Value.Profile(ID); + } +}; + +/// Represents an arbitrary, user-specified SPIR-V type instruction.
+class HLSLInlineSpirvType final + : public Type, + public llvm::FoldingSetNode, + private llvm::TrailingObjects<HLSLInlineSpirvType, SpirvOperand> { + friend class ASTContext; // ASTContext creates these + friend TrailingObjects; + +private: + uint32_t Opcode; + uint32_t Size; + uint32_t Alignment; + size_t NumOperands; + + HLSLInlineSpirvType(uint32_t Opcode, uint32_t Size, uint32_t Alignment, + ArrayRef<SpirvOperand> Operands) + : Type(HLSLInlineSpirv, QualType(), TypeDependence::None), Opcode(Opcode), + Size(Size), Alignment(Alignment), NumOperands(Operands.size()) { + // The trailing storage is raw memory handed to placement new by + // ASTContext, so the operands must be constructed in place. Assigning into + // that memory would run SpirvOperand's (and APInt's) copy assignment on an + // object that was never constructed. + std::uninitialized_copy(Operands.begin(), Operands.end(), + getTrailingObjects<SpirvOperand>()); + } + +public: + /// SPIR-V opcode of the type instruction. + uint32_t getOpcode() const { return Opcode; } + /// Size of the type in bytes (ASTContext scales it by 8 for the width). + uint32_t getSize() const { return Size; } + /// NOTE(review): ASTContext::getTypeInfo uses this value unscaled while it + /// scales Size from bytes to bits -- confirm the intended unit. + uint32_t getAlignment() const { return Alignment; } + /// Operands of the type instruction, in order. + ArrayRef<SpirvOperand> getOperands() const { + return {getTrailingObjects<SpirvOperand>(), NumOperands}; + } + + bool isSugared() const { return false; } + QualType desugar() const { return QualType(this, 0); } + + void Profile(llvm::FoldingSetNodeID &ID) { + Profile(ID, Opcode, Size, Alignment, getOperands()); + } + + static void Profile(llvm::FoldingSetNodeID &ID, uint32_t Opcode, + uint32_t Size, uint32_t Alignment, + ArrayRef<SpirvOperand> Operands) { + ID.AddInteger(Opcode); + ID.AddInteger(Size); + ID.AddInteger(Alignment); + for (auto &Operand : Operands) { + Operand.Profile(ID); + } + } + + static bool classof(const Type *T) { + 
return T->getTypeClass() == HLSLInlineSpirv; + } +}; + class TemplateTypeParmType : public Type, public llvm::FoldingSetNode { friend class ASTContext; // ASTContext creates these @@ -8458,13 +8593,18 @@ inline bool Type::isHLSLBuiltinIntangibleType() const { } inline bool Type::isHLSLSpecificType() const { - return isHLSLBuiltinIntangibleType() || isHLSLAttributedResourceType(); + return isHLSLBuiltinIntangibleType() || isHLSLAttributedResourceType() || + isHLSLInlineSpirvType(); } inline bool
Type::isHLSLAttributedResourceType() const { return isa<HLSLAttributedResourceType>(this); } +inline bool Type::isHLSLInlineSpirvType() const { + return isa<HLSLInlineSpirvType>(this); +} + inline bool Type::isTemplateTypeParmType() const { return isa<TemplateTypeParmType>(CanonicalType); } diff --git a/clang/include/clang/AST/TypeLoc.h b/clang/include/clang/AST/TypeLoc.h index 92661b8b13fe0..53c7ea8c65df2 100644 --- a/clang/include/clang/AST/TypeLoc.h +++ b/clang/include/clang/AST/TypeLoc.h @@ -973,6 +973,25 @@ class HLSLAttributedResourceTypeLoc } }; +struct HLSLInlineSpirvTypeLocInfo { + SourceLocation Loc; +}; // Nothing. + +class HLSLInlineSpirvTypeLoc + : public ConcreteTypeLoc<UnqualTypeLoc, HLSLInlineSpirvTypeLoc, + HLSLInlineSpirvType, HLSLInlineSpirvTypeLocInfo> { +public: + SourceLocation getSpirvTypeLoc() const { return getLocalData()->Loc; } + void setSpirvTypeLoc(SourceLocation loc) const { getLocalData()->Loc = loc; } + + SourceRange getLocalSourceRange() const { + return SourceRange(getSpirvTypeLoc(), getSpirvTypeLoc()); + } + void initializeLocal(ASTContext &Context, SourceLocation loc) { + setSpirvTypeLoc(loc); + } +}; + struct ObjCObjectTypeLocInfo { SourceLocation TypeArgsLAngleLoc; SourceLocation TypeArgsRAngleLoc; diff --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td index 391fd26a086f7..784c2104f1bb2 100644 --- a/clang/include/clang/AST/TypeProperties.td +++ b/clang/include/clang/AST/TypeProperties.td @@ -719,6 +719,24 @@ let Class = HLSLAttributedResourceType in { }]>; } +let Class = HLSLInlineSpirvType in { + def : Property<"opcode", UInt32> { + let Read = [{ node->getOpcode() }]; + } + def : Property<"size", UInt32> { + let Read = [{ node->getSize() }]; + } + def : Property<"alignment", UInt32> { + let Read = [{ node->getAlignment() }]; + } + def : Property<"operands", Array<HLSLSpirvOperand>> { + let Read = [{ node->getOperands() }]; + } + def : Creator<[{ + return 
ctx.getHLSLInlineSpirvType(opcode, size, alignment, operands); + }]>; +} + let Class = DependentAddressSpaceType in { def : Property<"pointeeType", QualType> { let Read = [{ node->getPointeeType() }]; diff --git a/clang/include/clang/Basic/BuiltinTemplates.td b/clang/include/clang/Basic/BuiltinTemplates.td index d46ce063d2f7e..5b9672b395955 100644 --- a/clang/include/clang/Basic/BuiltinTemplates.td +++ b/clang/include/clang/Basic/BuiltinTemplates.td @@ -28,25 +28,37 @@ class BuiltinNTTP<string type_name> : TemplateArg<""> { } def SizeT : BuiltinNTTP<"size_t"> {} +def Uint32T: BuiltinNTTP<"uint32_t"> {} class BuiltinTemplate<list<TemplateArg> template_head> { list<TemplateArg> TemplateHead = template_head; } +class CPlusPlusBuiltinTemplate<list<TemplateArg> template_head> : BuiltinTemplate<template_head>; + +class HLSLBuiltinTemplate<list<TemplateArg> template_head> : BuiltinTemplate<template_head>; + // template <template <class T, T... Ints> IntSeq, class T, T N> -def __make_integer_seq : BuiltinTemplate< +def __make_integer_seq : CPlusPlusBuiltinTemplate< [Template<[Class<"T">, NTTP<"T", "Ints", /*is_variadic=*/1>], "IntSeq">, Class<"T">, NTTP<"T", "N">]>; // template <size_t, class... T> -def __type_pack_element : BuiltinTemplate< +def __type_pack_element : CPlusPlusBuiltinTemplate< [SizeT, Class<"T", /*is_variadic=*/1>]>; // template <template <class... Args> BaseTemplate, // template <class TypeMember> HasTypeMember, // class HasNoTypeMember // class... 
Ts> -def __builtin_common_type : BuiltinTemplate< +def __builtin_common_type : CPlusPlusBuiltinTemplate< [Template<[Class<"Args", /*is_variadic=*/1>], "BaseTemplate">, Template<[Class<"TypeMember">], "HasTypeMember">, Class<"HasNoTypeMember">, Class<"Ts", /*is_variadic=*/1>]>; + +// template <uint32_t Opcode, +// uint32_t Size, +// uint32_t Alignment, +// typename ...Operands> +def __hlsl_spirv_type : HLSLBuiltinTemplate< +[Uint32T, Uint32T, Uint32T, Class<"Operands", /*is_variadic=*/1>]>; diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index 265bed2df43cf..287e139f02a2c 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -12709,6 +12709,9 @@ def err_hlsl_expect_arg_const_int_one_or_neg_one: Error< def err_invalid_hlsl_resource_type: Error< "invalid __hlsl_resource_t type attributes">; +def err_hlsl_spirv_only: Error<"%0 is only available for the SPIR-V target">; +def err_hlsl_vk_literal_must_contain_constant: Error<"the argument to vk::Literal must be a vk::integral_constant">; + // Layout randomization diagnostics. 
def err_non_designated_init_used : Error< "a randomized struct can only be initialized with a designated initializer">; diff --git a/clang/include/clang/Basic/TypeNodes.td b/clang/include/clang/Basic/TypeNodes.td index 7e550ca2992f3..567b8a5ca5a4d 100644 --- a/clang/include/clang/Basic/TypeNodes.td +++ b/clang/include/clang/Basic/TypeNodes.td @@ -94,6 +94,7 @@ def ElaboratedType : TypeNode<Type>, NeverCanonical; def AttributedType : TypeNode<Type>, NeverCanonical; def BTFTagAttributedType : TypeNode<Type>, NeverCanonical; def HLSLAttributedResourceType : TypeNode<Type>; +def HLSLInlineSpirvType : TypeNode<Type>; def TemplateTypeParmType : TypeNode<Type>, AlwaysDependent, LeafType; def SubstTemplateTypeParmType : TypeNode<Type>, NeverCanonical; def SubstTemplateTypeParmPackType : TypeNode<Type>, AlwaysDependent; diff --git a/clang/include/clang/Serialization/ASTRecordReader.h b/clang/include/clang/Serialization/ASTRecordReader.h index 7117b7246739b..79d33315d4fee 100644 --- a/clang/include/clang/Serialization/ASTRecordReader.h +++ b/clang/include/clang/Serialization/ASTRecordReader.h @@ -214,6 +214,8 @@ class ASTRecordReader TypeCoupledDeclRefInfo readTypeCoupledDeclRefInfo(); + SpirvOperand readHLSLSpirvOperand(); + /// Read a declaration name, advancing Idx. 
// DeclarationName readDeclarationName(); (inherited) DeclarationNameLoc readDeclarationNameLoc(DeclarationName Name); diff --git a/clang/include/clang/Serialization/ASTRecordWriter.h b/clang/include/clang/Serialization/ASTRecordWriter.h index 84d77e46016b7..9653b709d3ef5 100644 --- a/clang/include/clang/Serialization/ASTRecordWriter.h +++ b/clang/include/clang/Serialization/ASTRecordWriter.h @@ -151,6 +151,20 @@ class ASTRecordWriter writeBool(Info.isDeref()); } + void writeHLSLSpirvOperand(SpirvOperand Op) { + QualType ResultType; + llvm::APInt Value; + + if (Op.isConstant() || Op.isType()) + ResultType = Op.getResultType(); + if (Op.isConstant() || Op.isLiteral()) + Value = Op.getValue(); + + Record->push_back(Op.getKind()); + writeQualType(ResultType); + writeAPInt(Value); + } + /// Emit a source range. void AddSourceRange(SourceRange Range, LocSeq *Seq = nullptr) { return Writer->AddSourceRange(Range, *Record, Seq); diff --git a/clang/include/clang/Serialization/TypeBitCodes.def b/clang/include/clang/Serialization/TypeBitCodes.def index 3c78b87805010..b8cde2e370960 100644 --- a/clang/include/clang/Serialization/TypeBitCodes.def +++ b/clang/include/clang/Serialization/TypeBitCodes.def @@ -68,5 +68,6 @@ TYPE_BIT_CODE(PackIndexing, PACK_INDEXING, 56) TYPE_BIT_CODE(CountAttributed, COUNT_ATTRIBUTED, 57) TYPE_BIT_CODE(ArrayParameter, ARRAY_PARAMETER, 58) TYPE_BIT_CODE(HLSLAttributedResource, HLSLRESOURCE_ATTRIBUTED, 59) +TYPE_BIT_CODE(HLSLInlineSpirv, HLSL_INLINE_SPIRV, 60) #undef TYPE_BIT_CODE diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp index 552b5823add36..fb6a7b5a34175 100644 --- a/clang/lib/AST/ASTContext.cpp +++ b/clang/lib/AST/ASTContext.cpp @@ -2454,6 +2454,19 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const { return getTypeInfo( cast<HLSLAttributedResourceType>(T)->getWrappedType().getTypePtr()); + case Type::HLSLInlineSpirv: { + const auto *ST = cast<HLSLInlineSpirvType>(T); + // Size is specified in bytes, convert 
to bits + Width = ST->getSize() * 8; + Align = ST->getAlignment(); + if (Width == 0 && Align == 0) { + // We are defaulting to laying out opaque SPIR-V types as 32-bit ints. + Width = 32; + Align = 32; + } + break; + } + case Type::Atomic: { // Start with the base type information. TypeInfo Info = getTypeInfo(cast<AtomicType>(T)->getValueType()); @@ -3458,6 +3471,7 @@ static void encodeTypeForFunctionPointerAuth(const ASTContext &Ctx, return; } case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: llvm_unreachable("should never get here"); break; case Type::DeducedTemplateSpecialization: @@ -4179,6 +4193,7 @@ QualType ASTContext::getVariableArrayDecayedType(QualType type) const { case Type::DependentBitInt: case Type::ArrayParameter: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: llvm_unreachable("type should never be variably-modified"); // These types can be variably-modified but should never need to @@ -5444,6 +5459,31 @@ QualType ASTContext::getHLSLAttributedResourceType( return QualType(Ty, 0); } + +/// Returns the uniqued inline SPIR-V type for the given opcode, size, +/// alignment and operands, creating and registering it on first request. +QualType ASTContext::getHLSLInlineSpirvType(uint32_t Opcode, uint32_t Size, + uint32_t Alignment, + ArrayRef<SpirvOperand> Operands) { + llvm::FoldingSetNodeID ID; + HLSLInlineSpirvType::Profile(ID, Opcode, Size, Alignment, Operands); + + void *InsertPos = nullptr; + HLSLInlineSpirvType *Ty = + HLSLInlineSpirvTypes.FindNodeOrInsertPos(ID, InsertPos); + if (Ty) + return QualType(Ty, 0); + + // Let TrailingObjects compute the allocation size so any padding between + // the type object and its trailing operand array is accounted for (and the + // size is not narrowed to unsigned). + size_t AllocSize = + HLSLInlineSpirvType::totalSizeToAlloc<SpirvOperand>(Operands.size()); + void *Mem = Allocate(AllocSize, alignof(HLSLInlineSpirvType)); + + Ty = new (Mem) HLSLInlineSpirvType(Opcode, 
Size, Alignment, Operands); + + Types.push_back(Ty); + HLSLInlineSpirvTypes.InsertNode(Ty, InsertPos); + + return QualType(Ty, 0); +} + /// Retrieve a substitution-result type.
QualType ASTContext::getSubstTemplateTypeParmType( QualType Replacement, Decl *AssociatedDecl, unsigned Index, @@ -9335,6 +9375,7 @@ void ASTContext::getObjCEncodingForTypeImpl(QualType T, std::string &S, return; case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: llvm_unreachable("unexpected type"); case Type::ArrayParameter: @@ -11763,6 +11804,22 @@ QualType ASTContext::mergeTypes(QualType LHS, QualType RHS, bool OfBlockPointer, return LHS; return {}; } + case Type::HLSLInlineSpirv: { + const auto *LHSTy = LHS->castAs<HLSLInlineSpirvType>(); + const auto *RHSTy = RHS->castAs<HLSLInlineSpirvType>(); + + // Inline SPIR-V types merge only when every component matches. Comparing + // the operand lists as ArrayRefs also compares their lengths, so operand + // lists of different sizes are rejected instead of being indexed out of + // bounds. + if (LHSTy->getOpcode() == RHSTy->getOpcode() && + LHSTy->getSize() == RHSTy->getSize() && + LHSTy->getAlignment() == RHSTy->getAlignment() && + LHSTy->getOperands() == RHSTy->getOperands()) + return LHS; + return {}; + } } llvm_unreachable("Invalid Type::Class!"); @@ -13746,6 +13803,7 @@ static QualType getCommonNonSugarTypeNode(ASTContext &Ctx, const Type *X, SUGAR_FREE_TYPE(SubstTemplateTypeParmPack) SUGAR_FREE_TYPE(UnresolvedUsing) SUGAR_FREE_TYPE(HLSLAttributedResource) + SUGAR_FREE_TYPE(HLSLInlineSpirv) #undef SUGAR_FREE_TYPE #define NON_UNIQUE_TYPE(Class) UNEXPECTED_TYPE(Class, "non-unique") NON_UNIQUE_TYPE(TypeOfExpr) @@ -14089,6 +14147,7 @@ static QualType getCommonSugarTypeNode(ASTContext &Ctx, const Type *X, CANONICAL_TYPE(FunctionProto) CANONICAL_TYPE(IncompleteArray) CANONICAL_TYPE(HLSLAttributedResource) + CANONICAL_TYPE(HLSLInlineSpirv) CANONICAL_TYPE(LValueReference) CANONICAL_TYPE(ObjCInterface) CANONICAL_TYPE(ObjCObject) diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp index 81acb013b0f7d..8afb29cef24a4 100644 --- a/clang/lib/AST/ASTImporter.cpp +++ b/clang/lib/AST/ASTImporter.cpp @@ -1832,6 +1832,48 @@ 
ExpectedType clang::ASTNodeImporter::VisitHLSLAttributedResourceType(
ToContainedType, ToAttrs); } +ExpectedType clang::ASTNodeImporter::VisitHLSLInlineSpirvType( + const clang::HLSLInlineSpirvType *T) { + Error Err = Error::success(); + + uint32_t ToOpcode = T->getOpcode(); + uint32_t ToSize = T->getSize(); + uint32_t ToAlignment = T->getAlignment(); + + size_t NumOperands = T->getOperands().size(); + + llvm::SmallVector<SpirvOperand> ToOperands; + + size_t I = 0; + for (auto &Operand : T->getOperands()) { + using SpirvOperandKind = SpirvOperand::SpirvOperandKind; + + switch (Operand.getKind()) { + case SpirvOperandKind::kConstantId: + ToOperands.push_back(SpirvOperand::createConstant( + importChecked(Err, Operand.getResultType()), Operand.getValue())); + break; + case SpirvOperandKind::kLiteral: + ToOperands.push_back(SpirvOperand::createLiteral(Operand.getValue())); + break; + case SpirvOperandKind::kTypeId: + ToOperands.push_back(SpirvOperand::createType( + importChecked(Err, Operand.getResultType()))); + break; + default: + llvm_unreachable("Invalid SpirvOperand kind"); + } + + if (Err) + return std::move(Err); + } + + assert(I == NumOperands); + + return Importer.getToContext().getHLSLInlineSpirvType( + ToOpcode, ToSize, ToAlignment, ToOperands); +} + ExpectedType clang::ASTNodeImporter::VisitConstantMatrixType( const clang::ConstantMatrixType *T) { ExpectedType ToElementTypeOrErr = import(T->getElementType()); diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp index c769722521d9c..f213368f3e1cc 100644 --- a/clang/lib/AST/ASTStructuralEquivalence.cpp +++ b/clang/lib/AST/ASTStructuralEquivalence.cpp @@ -1119,6 +1119,23 @@ static bool IsStructurallyEquivalent(StructuralEquivalenceContext &Context, return false; break; + case Type::HLSLInlineSpirv: + if (cast<HLSLInlineSpirvType>(T1)->getOpcode() != + cast<HLSLInlineSpirvType>(T2)->getOpcode() || + cast<HLSLInlineSpirvType>(T1)->getSize() != + cast<HLSLInlineSpirvType>(T2)->getSize() || + 
cast<HLSLInlineSpirvType>(T1)->getAlignment() != + cast<HLSLInlineSpirvType>(T2)->getAlignment()) + return false; + for (size_t I = 0; I < cast<HLSLInlineSpirvType>(T1)->getOperands().size(); + I++) { + if (cast<HLSLInlineSpirvType>(T1)->getOperands()[I] != + cast<HLSLInlineSpirvType>(T2)->getOperands()[I]) { + return false; + } + } + break; + case Type::Paren: if (!IsStructurallyEquivalent(Context, cast<ParenType>(T1)->getInnerType(), cast<ParenType>(T2)->getInnerType())) diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp index 80ece3c4ed7e1..96055b03ccd73 100644 --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -12439,6 +12439,7 @@ GCCTypeClass EvaluateBuiltinClassifyType(QualType T, case Type::ObjCObjectPointer: case Type::Pipe: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: // Classify all other types that don't fit into the regular // classification the same way. return GCCTypeClass::None; diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp index b81981606866a..fe72305cd7535 100644 --- a/clang/lib/AST/ItaniumMangle.cpp +++ b/clang/lib/AST/ItaniumMangle.cpp @@ -2453,6 +2453,7 @@ bool CXXNameMangler::mangleUnresolvedTypeOrSimpleId(QualType Ty, case Type::Attributed: case Type::BTFTagAttributed: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: case Type::Auto: case Type::DeducedTemplateSpecialization: case Type::PackExpansion: @@ -4654,6 +4655,44 @@ void CXXNameMangler::mangleType(const HLSLAttributedResourceType *T) { mangleType(T->getWrappedType()); } +void CXXNameMangler::mangleType(const HLSLInlineSpirvType *T) { + SmallString<20> TypeNameStr; + llvm::raw_svector_ostream TypeNameOS(TypeNameStr); + + TypeNameOS << "spirv_type"; + + TypeNameOS << "_" << T->getOpcode(); + TypeNameOS << "_" << T->getSize(); + TypeNameOS << "_" << T->getAlignment(); + + mangleVendorType(TypeNameStr); + + for (auto &Operand : T->getOperands()) { + using SpirvOperandKind 
= SpirvOperand::SpirvOperandKind; + + switch (Operand.getKind()) { + case SpirvOperandKind::kConstantId: + mangleVendorQualifier("_Const"); + mangleIntegerLiteral(Operand.getResultType(), + llvm::APSInt(Operand.getValue())); + break; + case SpirvOperandKind::kLiteral: + mangleVendorQualifier("_Lit"); + mangleIntegerLiteral(Context.getASTContext().IntTy, + llvm::APSInt(Operand.getValue())); + break; + case SpirvOperandKind::kTypeId: + mangleVendorQualifier("_Type"); + mangleType(Operand.getResultType()); + break; + default: + llvm_unreachable("Invalid SpirvOperand kind"); + break; + } + TypeNameOS << Operand.getKind(); + } +} + void CXXNameMangler::mangleIntegerLiteral(QualType T, const llvm::APSInt &Value) { // <expr-primary> ::= L <type> <value number> E # integer literal @@ -4667,7 +4706,6 @@ void CXXNameMangler::mangleIntegerLiteral(QualType T, mangleNumber(Value); } Out << 'E'; - } void CXXNameMangler::mangleMemberExprBase(const Expr *Base, bool IsArrow) { diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp index 7e964124a9fec..f6b5937621154 100644 --- a/clang/lib/AST/MicrosoftMangle.cpp +++ b/clang/lib/AST/MicrosoftMangle.cpp @@ -3739,6 +3739,11 @@ void MicrosoftCXXNameMangler::mangleType(const HLSLAttributedResourceType *T, llvm_unreachable("HLSL uses Itanium name mangling"); } +void MicrosoftCXXNameMangler::mangleType(const HLSLInlineSpirvType *T, + Qualifiers, SourceRange Range) { + llvm_unreachable("HLSL uses Itanium name mangling"); +} + // <this-adjustment> ::= <no-adjustment> | <static-adjustment> | // <virtual-adjustment> // <no-adjustment> ::= A # private near diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp index 667ffc0e599a6..f9a6ccdb7bc6b 100644 --- a/clang/lib/AST/Type.cpp +++ b/clang/lib/AST/Type.cpp @@ -4654,6 +4654,8 @@ static CachedProperties computeCachedProperties(const Type *T) { return Cache::get(cast<PipeType>(T)->getElementType()); case Type::HLSLAttributedResource: return 
Cache::get(cast<HLSLAttributedResourceType>(T)->getWrappedType()); + case Type::HLSLInlineSpirv: + // Must agree with LinkageComputer::computeTypeLinkageInfo, which also + // reports external linkage for inline SPIR-V types. + return CachedProperties(Linkage::External, false); } llvm_unreachable("unhandled type class"); @@ -4748,6 +4750,17 @@ LinkageInfo LinkageComputer::computeTypeLinkageInfo(const Type *T) { return computeTypeLinkageInfo(cast<HLSLAttributedResourceType>(T) ->getContainedType() ->getCanonicalTypeInternal()); + case Type::HLSLInlineSpirv: + // Must agree with computeCachedProperties, which reports external linkage + // for inline SPIR-V types regardless of their operands. + // NOTE(review): operand types are not consulted. If a type operand with + // internal linkage should make the whole type internal, change both + // places together. + return LinkageInfo::external(); } llvm_unreachable("unhandled type class"); @@ -4938,6 +4951,7 @@ bool Type::canHaveNullability(bool ResultIfUnknown) const { case Type::DependentBitInt: case Type::ArrayParameter: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: return false; } llvm_unreachable("bad type kind!"); diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp index 4ec252e3f89b5..01fd90c40e4a5 100644 --- a/clang/lib/AST/TypePrinter.cpp +++ b/clang/lib/AST/TypePrinter.cpp @@ -247,6 +247,7 @@ bool TypePrinter::canPrefixQualifiers(const Type *T, case Type::DependentBitInt: case Type::BTFTagAttributed: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: CanPrefixQualifiers = true; break; @@ -2135,6 +2136,53 @@ void TypePrinter::printHLSLAttributedResourceAfter( } } +void TypePrinter::printHLSLInlineSpirvBefore(const HLSLInlineSpirvType *T, + raw_ostream &OS) { + OS << "__hlsl_spirv_type<" << T->getOpcode(); + + OS << ", " << T->getSize(); + OS << ", " << T->getAlignment(); + + for (auto &Operand : T->getOperands()) { + using SpirvOperandKind = SpirvOperand::SpirvOperandKind; + + OS << ", "; + switch (Operand.getKind()) { + case 
SpirvOperandKind::kConstantId: { + QualType ConstantType = Operand.getResultType(); + OS <<
"vk::integral_constant<"; + printBefore(ConstantType, OS); + printAfter(ConstantType, OS); + OS << ", "; + OS << Operand.getValue(); + OS << ">"; + break; + } + case SpirvOperandKind::kLiteral: + OS << "vk::Literal<vk::integral_constant<uint, "; + OS << Operand.getValue(); + OS << ">>"; + break; + case SpirvOperandKind::kTypeId: { + QualType Type = Operand.getResultType(); + printBefore(Type, OS); + printAfter(Type, OS); + break; + } + default: + llvm_unreachable("Invalid SpirvOperand kind!"); + break; + } + } + + OS << ">"; +} + +void TypePrinter::printHLSLInlineSpirvAfter(const HLSLInlineSpirvType *T, + raw_ostream &OS) { + // nothing to do +} + void TypePrinter::printObjCInterfaceBefore(const ObjCInterfaceType *T, raw_ostream &OS) { OS << T->getDecl()->getName(); diff --git a/clang/lib/CodeGen/CGDebugInfo.cpp b/clang/lib/CodeGen/CGDebugInfo.cpp index 52aa956121d73..179d93b2c3672 100644 --- a/clang/lib/CodeGen/CGDebugInfo.cpp +++ b/clang/lib/CodeGen/CGDebugInfo.cpp @@ -3521,6 +3521,12 @@ llvm::DIType *CGDebugInfo::CreateType(const HLSLAttributedResourceType *Ty, return getOrCreateType(Ty->getWrappedType(), U); } +llvm::DIType *CGDebugInfo::CreateType(const HLSLInlineSpirvType *Ty, + llvm::DIFile *U) { + // Debug information unneeded. 
+ return nullptr; +} + llvm::DIType *CGDebugInfo::CreateEnumType(const EnumType *Ty) { const EnumDecl *ED = Ty->getDecl(); @@ -3874,6 +3880,8 @@ llvm::DIType *CGDebugInfo::CreateTypeNode(QualType Ty, llvm::DIFile *Unit) { return CreateType(cast<TemplateSpecializationType>(Ty), Unit); case Type::HLSLAttributedResource: return CreateType(cast<HLSLAttributedResourceType>(Ty), Unit); + case Type::HLSLInlineSpirv: + return CreateType(cast<HLSLInlineSpirvType>(Ty), Unit); case Type::CountAttributed: case Type::Auto: diff --git a/clang/lib/CodeGen/CGDebugInfo.h b/clang/lib/CodeGen/CGDebugInfo.h index b287ce7b92eee..7a63fa4b00278 100644 --- a/clang/lib/CodeGen/CGDebugInfo.h +++ b/clang/lib/CodeGen/CGDebugInfo.h @@ -198,6 +198,7 @@ class CGDebugInfo { llvm::DIType *CreateType(const FunctionType *Ty, llvm::DIFile *F); llvm::DIType *CreateType(const HLSLAttributedResourceType *Ty, llvm::DIFile *F); + llvm::DIType *CreateType(const HLSLInlineSpirvType *Ty, llvm::DIFile *F); /// Get structure or union type. llvm::DIType *CreateType(const RecordType *Tyg); diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp index dcf523f56bf1e..0fcf434e1d033 100644 --- a/clang/lib/CodeGen/CodeGenFunction.cpp +++ b/clang/lib/CodeGen/CodeGenFunction.cpp @@ -283,6 +283,7 @@ TypeEvaluationKind CodeGenFunction::getEvaluationKind(QualType type) { case Type::Pipe: case Type::BitInt: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: return TEK_Scalar; // Complexes. 
@@ -2452,6 +2453,7 @@ void CodeGenFunction::EmitVariablyModifiedType(QualType type) { case Type::ObjCInterface: case Type::ObjCObjectPointer: case Type::BitInt: + case Type::HLSLInlineSpirv: llvm_unreachable("type class is never variably-modified!"); case Type::Elaborated: diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp index 11cf5758b6d3a..3a22888647445 100644 --- a/clang/lib/CodeGen/CodeGenTypes.cpp +++ b/clang/lib/CodeGen/CodeGenTypes.cpp @@ -767,6 +767,7 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) { break; } case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: ResultType = CGM.getHLSLRuntime().convertHLSLSpecificType(Ty); break; } @@ -877,6 +878,11 @@ bool CodeGenTypes::isZeroInitializable(QualType T) { if (const MemberPointerType *MPT = T->getAs<MemberPointerType>()) return getCXXABI().isZeroInitializable(MPT); + // HLSL Inline SPIR-V types are non-zero-initializable. + if (T->getAs<HLSLInlineSpirvType>()) { + return false; + } + // Everything else is okay. 
return true; } diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp index 38e3a63ebfb11..71ce242b10b99 100644 --- a/clang/lib/CodeGen/ItaniumCXXABI.cpp +++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp @@ -3959,6 +3959,7 @@ void ItaniumRTTIBuilder::BuildVTablePointer(const Type *Ty, break; case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: llvm_unreachable("HLSL doesn't support virtual functions"); } @@ -4237,6 +4238,7 @@ llvm::Constant *ItaniumRTTIBuilder::BuildTypeInfo( break; case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: llvm_unreachable("HLSL doesn't support RTTI"); } diff --git a/clang/lib/CodeGen/Targets/SPIR.cpp b/clang/lib/CodeGen/Targets/SPIR.cpp index 225d9dfbd980b..1c1f243dc84c7 100644 --- a/clang/lib/CodeGen/Targets/SPIR.cpp +++ b/clang/lib/CodeGen/Targets/SPIR.cpp @@ -369,14 +369,102 @@ llvm::Type *CommonSPIRTargetCodeGenInfo::getOpenCLType(CodeGenModule &CGM, return nullptr; } +// Gets a spirv.IntegralConstant or spirv.Literal. If IntegralType is present, +// returns an IntegralConstant, otherwise returns a Literal. 
+static llvm::Type *getInlineSpirvConstant(CodeGenModule &CGM, + llvm::Type *IntegralType, + llvm::APInt Value) { + llvm::LLVMContext &Ctx = CGM.getLLVMContext(); + + // Convert the APInt value to an array of uint32_t words + llvm::SmallVector<uint32_t> Words; + + while (Value.ugt(0)) { + uint32_t Word = Value.trunc(32).getZExtValue(); + Value.lshrInPlace(32); + + Words.push_back(Word); + } + if (Words.size() == 0) + Words.push_back(0); + + if (IntegralType) { + return llvm::TargetExtType::get(Ctx, "spirv.IntegralConstant", + {IntegralType}, Words); + } else { + return llvm::TargetExtType::get(Ctx, "spirv.Literal", {}, Words); + } +} + +static llvm::Type *getInlineSpirvType(CodeGenModule &CGM, + const HLSLInlineSpirvType *SpirvType) { + llvm::LLVMContext &Ctx = CGM.getLLVMContext(); + + llvm::SmallVector<llvm::Type *> Operands; + + for (auto &Operand : SpirvType->getOperands()) { + using SpirvOperandKind = SpirvOperand::SpirvOperandKind; + + llvm::Type *Result = nullptr; + switch (Operand.getKind()) { + case SpirvOperandKind::kConstantId: { + llvm::Type *IntegralType = + CGM.getTypes().ConvertType(Operand.getResultType()); + llvm::APInt Value = Operand.getValue(); + + Result = getInlineSpirvConstant(CGM, IntegralType, Value); + break; + } + case SpirvOperandKind::kLiteral: { + llvm::APInt Value = Operand.getValue(); + Result = getInlineSpirvConstant(CGM, nullptr, Value); + break; + } + case SpirvOperandKind::kTypeId: { + QualType TypeOperand = Operand.getResultType(); + if (auto *RT = TypeOperand->getAs<RecordType>()) { + auto *RD = RT->getDecl(); + assert(RD->isCompleteDefinition() && + "Type completion should have been required in Sema"); + + const FieldDecl *HandleField = RD->findFirstNamedDataMember(); + if (HandleField) { + QualType ResourceType = HandleField->getType(); + if (ResourceType->getAs<HLSLAttributedResourceType>()) { + TypeOperand = ResourceType; + } + } + } + Result = CGM.getTypes().ConvertType(TypeOperand); + break; + } + default: + 
llvm_unreachable("HLSLInlineSpirvType had invalid operand!"); + break; + } + + assert(Result); + Operands.push_back(Result); + } + + return llvm::TargetExtType::get(Ctx, "spirv.Type", Operands, + {SpirvType->getOpcode(), SpirvType->getSize(), + SpirvType->getAlignment()}); +} + llvm::Type *CommonSPIRTargetCodeGenInfo::getHLSLType( CodeGenModule &CGM, const Type *Ty, const SmallVector<int32_t> *Packoffsets) const { + llvm::LLVMContext &Ctx = CGM.getLLVMContext(); + + if (auto *SpirvType = dyn_cast<HLSLInlineSpirvType>(Ty)) { + return getInlineSpirvType(CGM, SpirvType); + } + auto *ResType = dyn_cast<HLSLAttributedResourceType>(Ty); if (!ResType) return nullptr; - llvm::LLVMContext &Ctx = CGM.getLLVMContext(); const HLSLAttributedResourceType::Attributes &ResAttrs = ResType->getAttrs(); switch (ResAttrs.ResourceClass) { case llvm::dxil::ResourceClass::UAV: diff --git a/clang/lib/Headers/CMakeLists.txt b/clang/lib/Headers/CMakeLists.txt index acf49e40c447e..e9555cf7e0898 100644 --- a/clang/lib/Headers/CMakeLists.txt +++ b/clang/lib/Headers/CMakeLists.txt @@ -91,6 +91,7 @@ set(hlsl_subdir_files hlsl/hlsl_intrinsic_helpers.h hlsl/hlsl_intrinsics.h hlsl/hlsl_detail.h + hlsl/hlsl_spirv.h ) set(hlsl_files ${hlsl_h} diff --git a/clang/lib/Headers/hlsl.h b/clang/lib/Headers/hlsl.h index b494b4d0f78bb..684d29d5ed55b 100644 --- a/clang/lib/Headers/hlsl.h +++ b/clang/lib/Headers/hlsl.h @@ -27,6 +27,10 @@ #endif #include "hlsl/hlsl_intrinsics.h" +#ifdef __spirv__ +#include "hlsl/hlsl_spirv.h" +#endif // __spirv__ + #if defined(__clang__) #pragma clang diagnostic pop #endif diff --git a/clang/lib/Headers/hlsl/hlsl_spirv.h b/clang/lib/Headers/hlsl/hlsl_spirv.h new file mode 100644 index 0000000000000..8a71699a4ed5c --- /dev/null +++ b/clang/lib/Headers/hlsl/hlsl_spirv.h @@ -0,0 +1,30 @@ +//===----- hlsl_spirv.h - HLSL definitions for SPIR-V target --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef _HLSL_HLSL_SPIRV_H_ +#define _HLSL_HLSL_SPIRV_H_ + +namespace hlsl { + namespace vk { + // template <class T> using Foo = __hlsl_spirv_t; + // typedef Foo + template <typename T, T v> struct integral_constant { + static constexpr T value = v; + }; + + template <typename T> struct Literal {}; + + template <uint Opcode, uint Size, uint Alignment, typename... Operands> + using SpirvType = __hlsl_spirv_type<Opcode, Size, Alignment, Operands...>; + + template <uint Opcode, typename... Operands> + using SpirvOpaqueType = __hlsl_spirv_type<Opcode, 0, 0, Operands...>; + } // namespace vk + } // namespace hlsl + +#endif // _HLSL_HLSL_SPIRV_H_ \ No newline at end of file diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp index e7f418ae6802e..b27c884bb0d8c 100644 --- a/clang/lib/Sema/SemaExpr.cpp +++ b/clang/lib/Sema/SemaExpr.cpp @@ -4478,6 +4478,7 @@ static void captureVariablyModifiedType(ASTContext &Context, QualType T, case Type::ObjCTypeParam: case Type::Pipe: case Type::BitInt: + case Type::HLSLInlineSpirv: llvm_unreachable("type class is never variably-modified!"); case Type::Elaborated: T = cast<ElaboratedType>(Ty)->getNamedType(); diff --git a/clang/lib/Sema/SemaLookup.cpp b/clang/lib/Sema/SemaLookup.cpp index a77ca779a9ee3..a1f5d673c1b5a 100644 --- a/clang/lib/Sema/SemaLookup.cpp +++ b/clang/lib/Sema/SemaLookup.cpp @@ -923,13 +923,25 @@ bool Sema::LookupBuiltin(LookupResult &R) { NameKind == Sema::LookupRedeclarationWithLinkage) { IdentifierInfo *II = R.getLookupName().getAsIdentifierInfo(); if (II) { - if (getLangOpts().CPlusPlus && NameKind == Sema::LookupOrdinaryName) { -#define BuiltinTemplate(BIName) \ + if (NameKind == Sema::LookupOrdinaryName) { + if (getLangOpts().CPlusPlus) { +#define BuiltinTemplate(BIName) +#define 
CPlusPlusBuiltinTemplate(BIName) \ if (II == getASTContext().get##BIName##Name()) { \ R.addDecl(getASTContext().get##BIName##Decl()); \ return true; \ } #include "clang/Basic/BuiltinTemplates.inc" + } + if (getLangOpts().HLSL) { +#define BuiltinTemplate(BIName) +#define HLSLBuiltinTemplate(BIName) \ + if (II == getASTContext().get##BIName##Name()) { \ + R.addDecl(getASTContext().get##BIName##Decl()); \ + return true; \ + } +#include "clang/Basic/BuiltinTemplates.inc" + } } // Check if this is an OpenCL Builtin, and if so, insert its overloads. @@ -3265,6 +3277,11 @@ addAssociatedClassesAndNamespaces(AssociatedLookup &Result, QualType Ty) { case Type::HLSLAttributedResource: T = cast<HLSLAttributedResourceType>(T)->getWrappedType().getTypePtr(); + break; + + // Inline SPIR-V types are treated as fundamental types. + case Type::HLSLInlineSpirv: + break; } if (Queue.empty()) diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp index 1f87ef4b27bab..baab78327dc6d 100644 --- a/clang/lib/Sema/SemaTemplate.cpp +++ b/clang/lib/Sema/SemaTemplate.cpp @@ -3228,6 +3228,62 @@ static QualType builtinCommonTypeImpl(Sema &S, TemplateName BaseTemplate, } } +static bool isInVkNamespace(const RecordType *RT) { + DeclContext *DC = RT->getDecl()->getDeclContext(); + if (!DC) + return false; + + NamespaceDecl *ND = dyn_cast<NamespaceDecl>(DC); + if (!ND) + return false; + + return ND->getQualifiedNameAsString() == "hlsl::vk"; +} + +static SpirvOperand checkHLSLSpirvTypeOperand(Sema &SemaRef, + QualType OperandArg, + SourceLocation Loc) { + if (auto *RT = OperandArg->getAs<RecordType>()) { + bool Literal = false; + SourceLocation LiteralLoc; + if (isInVkNamespace(RT) && RT->getDecl()->getName() == "Literal") { + auto SpecDecl = dyn_cast<ClassTemplateSpecializationDecl>(RT->getDecl()); + assert(SpecDecl); + + const TemplateArgumentList &LiteralArgs = SpecDecl->getTemplateArgs(); + QualType ConstantType = LiteralArgs[0].getAsType(); + RT = 
ConstantType->getAs<RecordType>(); + Literal = true; + LiteralLoc = SpecDecl->getSourceRange().getBegin(); + } + + if (RT && isInVkNamespace(RT) && + RT->getDecl()->getName() == "integral_constant") { + auto SpecDecl = dyn_cast<ClassTemplateSpecializationDecl>(RT->getDecl()); + assert(SpecDecl); + + const TemplateArgumentList &ConstantArgs = SpecDecl->getTemplateArgs(); + + QualType ConstantType = ConstantArgs[0].getAsType(); + llvm::APInt Value = ConstantArgs[1].getAsIntegral(); + + if (Literal) { + return SpirvOperand::createLiteral(Value); + } else { + return SpirvOperand::createConstant(ConstantType, Value); + } + } else if (Literal) { + SemaRef.Diag(LiteralLoc, diag::err_hlsl_vk_literal_must_contain_constant); + return SpirvOperand(); + } + } + if (SemaRef.RequireCompleteType(Loc, OperandArg, + diag::err_call_incomplete_argument)) { + return SpirvOperand(); + } + return SpirvOperand::createType(OperandArg); +} + static QualType checkBuiltinTemplateIdType(Sema &SemaRef, BuiltinTemplateDecl *BTD, ArrayRef<TemplateArgument> Converted, @@ -3289,7 +3345,7 @@ checkBuiltinTemplateIdType(Sema &SemaRef, BuiltinTemplateDecl *BTD, // __type_pack_element<Index, T_1, ..., T_N> // are treated like T_Index. 
assert(Converted.size() == 2 && - "__type_pack_element should be given an index and a parameter pack"); + "__type_pack_element should be given an index and a parameter pack"); TemplateArgument IndexArg = Converted[0], Ts = Converted[1]; if (IndexArg.isDependent() || Ts.isDependent()) @@ -3332,6 +3388,39 @@ checkBuiltinTemplateIdType(Sema &SemaRef, BuiltinTemplateDecl *BTD, } return HasNoTypeMember; } + + case BTK__hlsl_spirv_type: { + assert(Converted.size() == 4); + + if (!Context.getTargetInfo().getTriple().isSPIRV()) { + SemaRef.Diag(TemplateLoc, diag::err_hlsl_spirv_only) + << "__hlsl_spirv_type"; + } + + if (llvm::any_of(Converted, [](auto &C) { return C.isDependent(); })) + return Context.getCanonicalTemplateSpecializationType(TemplateName(BTD), + Converted); + + uint64_t Opcode = Converted[0].getAsIntegral().getZExtValue(); + uint64_t Size = Converted[1].getAsIntegral().getZExtValue(); + uint64_t Alignment = Converted[2].getAsIntegral().getZExtValue(); + + ArrayRef<TemplateArgument> OperandArgs = Converted[3].getPackAsArray(); + + llvm::SmallVector<SpirvOperand> Operands; + + for (auto &OperandTA : OperandArgs) { + QualType OperandArg = OperandTA.getAsType(); + auto Operand = checkHLSLSpirvTypeOperand(SemaRef, OperandArg, + TemplateArgs[3].getLocation()); + if (!Operand.isValid()) { + return QualType(); + } + Operands.push_back(Operand); + } + + return Context.getHLSLInlineSpirvType(Opcode, Size, Alignment, Operands); + } } llvm_unreachable("unexpected BuiltinTemplateDecl!"); } @@ -6165,6 +6254,18 @@ bool UnnamedLocalNoLinkageFinder::VisitHLSLAttributedResourceType( return Visit(T->getWrappedType()); } +bool UnnamedLocalNoLinkageFinder::VisitHLSLInlineSpirvType( + const HLSLInlineSpirvType *T) { + for (auto &Operand : T->getOperands()) { + if (Operand.isConstant() && Operand.isLiteral()) { + if (Visit(Operand.getResultType())) { + return true; + } + } + } + return false; +} + bool Sema::CheckTemplateArgument(TypeSourceInfo *ArgInfo) { assert(ArgInfo && 
"invalid TypeSourceInfo"); QualType Arg = ArgInfo->getType(); diff --git a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp index 9969f1762fe36..147e078539a5d 100644 --- a/clang/lib/Sema/SemaTemplateDeduction.cpp +++ b/clang/lib/Sema/SemaTemplateDeduction.cpp @@ -2492,6 +2492,7 @@ static TemplateDeductionResult DeduceTemplateArgumentsByTypeMatch( case Type::Pipe: case Type::ArrayParameter: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: // No template argument deduction for these types return TemplateDeductionResult::Success; @@ -7116,6 +7117,7 @@ MarkUsedTemplateParameters(ASTContext &Ctx, QualType T, case Type::UnresolvedUsing: case Type::Pipe: case Type::BitInt: + case Type::HLSLInlineSpirv: #define TYPE(Class, Base) #define ABSTRACT_TYPE(Class, Base) #define DEPENDENT_TYPE(Class, Base) diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp index 2df961a48c7c3..e3070675fa2c1 100644 --- a/clang/lib/Sema/SemaType.cpp +++ b/clang/lib/Sema/SemaType.cpp @@ -5874,6 +5874,7 @@ namespace { Visit(TL.getWrappedLoc()); fillHLSLAttributedResourceTypeLoc(TL, State); } + void VisitHLSLInlineSpirvTypeLoc(HLSLInlineSpirvTypeLoc TL) {} void VisitMacroQualifiedTypeLoc(MacroQualifiedTypeLoc TL) { Visit(TL.getInnerLoc()); TL.setExpansionLoc( diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h index 916b8e2735cd0..f3bbaf78ceddf 100644 --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -7652,6 +7652,13 @@ QualType TreeTransform<Derived>::TransformHLSLAttributedResourceType( return Result; } +template <typename Derived> +QualType TreeTransform<Derived>::TransformHLSLInlineSpirvType( + TypeLocBuilder &TLB, HLSLInlineSpirvTypeLoc TL) { + // No transformations needed. 
+ return TL.getType(); +} + template<typename Derived> QualType TreeTransform<Derived>::TransformParenType(TypeLocBuilder &TLB, diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp index 58a57d6c54523..7b2bb7b00fbff 100644 --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -7284,6 +7284,10 @@ void TypeLocReader::VisitHLSLAttributedResourceTypeLoc( // Nothing to do. } +void TypeLocReader::VisitHLSLInlineSpirvTypeLoc(HLSLInlineSpirvTypeLoc TL) { + // Nothing to do. +} + void TypeLocReader::VisitTemplateTypeParmTypeLoc(TemplateTypeParmTypeLoc TL) { TL.setNameLoc(readSourceLocation()); } @@ -9753,6 +9757,11 @@ TypeCoupledDeclRefInfo ASTRecordReader::readTypeCoupledDeclRefInfo() { return TypeCoupledDeclRefInfo(readDeclAs<ValueDecl>(), readBool()); } +SpirvOperand ASTRecordReader::readHLSLSpirvOperand() { + return SpirvOperand(SpirvOperand::SpirvOperandKind(readInt()), readQualType(), + readAPInt()); +} + void ASTRecordReader::readQualifierInfo(QualifierInfo &Info) { Info.QualifierLoc = readNestedNameSpecifierLoc(); unsigned NumTPLists = readInt(); diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp index 84f7f2bc5fce4..ac6647632c4ea 100644 --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -604,6 +604,10 @@ void TypeLocWriter::VisitHLSLAttributedResourceTypeLoc( // Nothing to do. } +void TypeLocWriter::VisitHLSLInlineSpirvTypeLoc(HLSLInlineSpirvTypeLoc TL) { + // Nothing to do. 
+} + void TypeLocWriter::VisitTemplateTypeParmTypeLoc(TemplateTypeParmTypeLoc TL) { addSourceLocation(TL.getNameLoc()); } diff --git a/clang/test/AST/HLSL/Inputs/pch_spirv_type.hlsl b/clang/test/AST/HLSL/Inputs/pch_spirv_type.hlsl new file mode 100644 index 0000000000000..326c75dbc3bbe --- /dev/null +++ b/clang/test/AST/HLSL/Inputs/pch_spirv_type.hlsl @@ -0,0 +1,6 @@ + +float2 foo(float2 a, float2 b) { + return a + b; +} + +vk::SpirvOpaqueType</* OpTypeArray */ 28, RWBuffer<float>, vk::integral_constant<uint, 4>> buffers; diff --git a/clang/test/AST/HLSL/ast-dump-SpirvType.hlsl b/clang/test/AST/HLSL/ast-dump-SpirvType.hlsl new file mode 100644 index 0000000000000..f9aaf368ac935 --- /dev/null +++ b/clang/test/AST/HLSL/ast-dump-SpirvType.hlsl @@ -0,0 +1,27 @@ +// RUN: %clang_cc1 -finclude-default-header -triple spirv-unknown-vulkan-compute -x hlsl -ast-dump -o - %s | FileCheck %s + +// CHECK: TypedefDecl 0x{{.+}} <{{.+}}:4:1, col:83> col:83 referenced AType 'vk::SpirvOpaqueType<123, RWBuffer<float>, vk::integral_constant<uint, 4>>':'__hlsl_spirv_type<123, 0, 0, RWBuffer<float>, vk::integral_constant<unsigned int, 4>>' +typedef vk::SpirvOpaqueType<123, RWBuffer<float>, vk::integral_constant<uint, 4>> AType; +// CHECK: TypedefDecl 0x{{.+}} <{{.+}}:6:1, col:133> col:133 referenced BType 'vk::SpirvType<12, 2, 4, vk::integral_constant<uint64_t, 4886718345L>, float, vk::Literal<vk::integral_constant<uint, 456>>>':'__hlsl_spirv_type<12, 2, 4, vk::integral_constant<unsigned long, 4886718345>, float, vk::Literal<vk::integral_constant<uint, 456>>>' +typedef vk::SpirvType<12, 2, 4, vk::integral_constant<uint64_t, 0x123456789>, float, vk::Literal<vk::integral_constant<uint, 456>>> BType; + +// CHECK: VarDecl 0x{{.+}} <{{.+}}:9:1, col:7> col:7 AValue 'hlsl_constant AType':'hlsl_constant __hlsl_spirv_type<123, 0, 0, RWBuffer<float>, vk::integral_constant<unsigned int, 4>>' +AType AValue; +// CHECK: VarDecl 0x{{.+}} <{{.+}}:11:1, col:7> col:7 BValue 'hlsl_constant 
BType':'hlsl_constant __hlsl_spirv_type<12, 2, 4, vk::integral_constant<unsigned long, 4886718345>, float, vk::Literal<vk::integral_constant<uint, 456>>>' +BType BValue; + +// CHECK: VarDecl 0x{{.+}} <{{.+}}:14:1, col:80> col:80 CValue 'hlsl_constant vk::SpirvOpaqueType<123, vk::Literal<vk::integral_constant<uint, 305419896>>>':'hlsl_constant __hlsl_spirv_type<123, 0, 0, vk::Literal<vk::integral_constant<uint, 305419896>>>' +vk::SpirvOpaqueType<123, vk::Literal<vk::integral_constant<uint, 0x12345678>>> CValue; + +// CHECK: TypeAliasDecl 0x{{.+}} <{{.+}}:18:1, col:72> col:7 Array 'vk::SpirvOpaqueType<28, T, vk::integral_constant<uint, L>>':'__hlsl_spirv_type<28, 0, 0, type-parameter-0-0, integral_constant<unsigned int, L>>' +template <class T, uint L> +using Array = vk::SpirvOpaqueType<28, T, vk::integral_constant<uint, L>>; + +// CHECK: VarDecl 0x{{.+}} <{{.+}}:21:1, col:16> col:16 DValue 'hlsl_constant Array<uint, 5>':'hlsl_constant __hlsl_spirv_type<28, 0, 0, uint, vk::integral_constant<unsigned int, 5>>' +Array<uint, 5> DValue; + +[numthreads(1, 1, 1)] +void main() { +// CHECK: VarDecl 0x{{.+}} <col:5, col:11> col:11 EValue 'AType':'__hlsl_spirv_type<123, 0, 0, RWBuffer<float>, vk::integral_constant<unsigned int, 4>>' + AType EValue; +} diff --git a/clang/test/AST/HLSL/pch_spirv_type.hlsl b/clang/test/AST/HLSL/pch_spirv_type.hlsl new file mode 100644 index 0000000000000..045f89a1b8461 --- /dev/null +++ b/clang/test/AST/HLSL/pch_spirv_type.hlsl @@ -0,0 +1,17 @@ +// RUN: %clang_cc1 -triple spirv-unknown-vulkan-library -x hlsl \ +// RUN: -finclude-default-header -emit-pch -o %t %S/Inputs/pch_spirv_type.hlsl +// RUN: %clang_cc1 -triple spirv-unknown-vulkan-library -x hlsl \ +// RUN: -finclude-default-header -include-pch %t -ast-dump-all %s \ +// RUN: | FileCheck %s + +// Make sure PCH works by using function declared in PCH header and declare a SpirvType in current file. 
+// CHECK:FunctionDecl 0x[[FOO:[0-9a-f]+]] <{{.*}}:2:1, line:4:1> line:2:8 imported used foo 'float2 (float2, float2)' +// CHECK:VarDecl 0x{{[0-9a-f]+}} <{{.*}}:10:1, col:92> col:92 buffers2 'hlsl_constant vk::SpirvOpaqueType<28, RWBuffer<float>, vk::integral_constant<uint, 4>>':'hlsl_constant __hlsl_spirv_type<28, 0, 0, RWBuffer<float>, vk::integral_constant<unsigned int, 4>>' +vk::SpirvOpaqueType</* OpTypeArray */ 28, RWBuffer<float>, vk::integral_constant<uint, 4>> buffers2; + +float2 bar(float2 a, float2 b) { +// CHECK:CallExpr 0x{{[0-9a-f]+}} <col:10, col:18> 'float2':'vector<float, 2>' +// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <col:10> 'float2 (*)(float2, float2)' <FunctionToPointerDecay> +// CHECK-NEXT:`-DeclRefExpr 0x{{[0-9a-f]+}} <col:10> 'float2 (float2, float2)' lvalue Function 0x[[FOO]] 'foo' 'float2 (float2, float2)' + return foo(a, b); +} diff --git a/clang/test/AST/HLSL/vector-alias.hlsl b/clang/test/AST/HLSL/vector-alias.hlsl index 58d80e9b4a4e4..e1f78e6abdca8 100644 --- a/clang/test/AST/HLSL/vector-alias.hlsl +++ b/clang/test/AST/HLSL/vector-alias.hlsl @@ -1,53 +1,52 @@ -// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s - -// CHECK: NamespaceDecl {{.*}} implicit hlsl -// CHECK-NEXT: TypeAliasTemplateDecl {{.*}} implicit vector -// CHECK-NEXT: TemplateTypeParmDecl {{.*}} class depth 0 index 0 element -// CHECK-NEXT: TemplateArgument type 'float' -// CHECK-NEXT: BuiltinType {{.*}} 'float' -// CHECK-NEXT: NonTypeTemplateParmDecl {{.*}} 'int' depth 0 index 1 element_count -// CHECK-NEXT: TemplateArgument expr -// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 4 -// CHECK-NEXT: TypeAliasDecl {{.*}} implicit vector 'vector<element, element_count>' -// CHECK-NEXT: DependentSizedExtVectorType {{.*}} 'vector<element, element_count>' dependent -// CHECK-NEXT: TemplateTypeParmType {{.*}} 'element' dependent depth 0 index 0 -// CHECK-NEXT: TemplateTypeParm {{.*}} 'element' -// CHECK-NEXT: DeclRefExpr {{.*}} 'int' 
lvalue -// NonTypeTemplateParm {{.*}} 'element_count' 'int' - -// Make sure we got a using directive at the end. -// CHECK: UsingDirectiveDecl {{.*}} Namespace {{.*}} 'hlsl' - -[numthreads(1,1,1)] -int entry() { - // Verify that the alias is generated inside the hlsl namespace. - hlsl::vector<float, 2> Vec2 = {1.0, 2.0}; - - // CHECK: DeclStmt - // CHECK-NEXT: VarDecl {{.*}} Vec2 'hlsl::vector<float, 2>':'vector<float, 2>' cinit - - // Verify that you don't need to specify the namespace. - vector<int, 2> Vec2a = {1, 2}; - - // CHECK: DeclStmt - // CHECK-NEXT: VarDecl {{.*}} Vec2a 'vector<int, 2>' cinit - - // Build a bigger vector. - vector<double, 4> Vec4 = {1.0, 2.0, 3.0, 4.0}; - - // CHECK: DeclStmt - // CHECK-NEXT: VarDecl {{.*}} used Vec4 'vector<double, 4>' cinit - - // Verify that swizzles still work. - vector<double, 3> Vec3 = Vec4.xyz; - - // CHECK: DeclStmt {{.*}} - // CHECK-NEXT: VarDecl {{.*}} Vec3 'vector<double, 3>' cinit - - // Verify that the implicit arguments generate the correct type. 
- vector<> ImpVec4 = {1.0, 2.0, 3.0, 4.0}; - - // CHECK: DeclStmt - // CHECK-NEXT: VarDecl {{.*}} ImpVec4 'vector<>':'vector<float, 4>' cinit - return 1; -} +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s +// CHECK: NamespaceDecl {{.*}} implicit hlsl +// CHECK: TypeAliasTemplateDecl {{.*}} implicit vector +// CHECK-NEXT: TemplateTypeParmDecl {{.*}} class depth 0 index 0 element +// CHECK-NEXT: TemplateArgument type 'float' +// CHECK-NEXT: BuiltinType {{.*}} 'float' +// CHECK-NEXT: NonTypeTemplateParmDecl {{.*}} 'int' depth 0 index 1 element_count +// CHECK-NEXT: TemplateArgument expr +// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 4 +// CHECK-NEXT: TypeAliasDecl {{.*}} implicit vector 'vector<element, element_count>' +// CHECK-NEXT: DependentSizedExtVectorType {{.*}} 'vector<element, element_count>' dependent +// CHECK-NEXT: TemplateTypeParmType {{.*}} 'element' dependent depth 0 index 0 +// CHECK-NEXT: TemplateTypeParm {{.*}} 'element' +// CHECK-NEXT: DeclRefExpr {{.*}} 'int' lvalue +// NonTypeTemplateParm {{.*}} 'element_count' 'int' + +// Make sure we got a using directive at the end. +// CHECK: UsingDirectiveDecl {{.*}} Namespace {{.*}} 'hlsl' + +[numthreads(1,1,1)] +int entry() { + // Verify that the alias is generated inside the hlsl namespace. + hlsl::vector<float, 2> Vec2 = {1.0, 2.0}; + + // CHECK: DeclStmt + // CHECK-NEXT: VarDecl {{.*}} Vec2 'hlsl::vector<float, 2>':'vector<float, 2>' cinit + + // Verify that you don't need to specify the namespace. + vector<int, 2> Vec2a = {1, 2}; + + // CHECK: DeclStmt + // CHECK-NEXT: VarDecl {{.*}} Vec2a 'vector<int, 2>' cinit + + // Build a bigger vector. + vector<double, 4> Vec4 = {1.0, 2.0, 3.0, 4.0}; + + // CHECK: DeclStmt + // CHECK-NEXT: VarDecl {{.*}} used Vec4 'vector<double, 4>' cinit + + // Verify that swizzles still work. 
+ vector<double, 3> Vec3 = Vec4.xyz; + + // CHECK: DeclStmt {{.*}} + // CHECK-NEXT: VarDecl {{.*}} Vec3 'vector<double, 3>' cinit + + // Verify that the implicit arguments generate the correct type. + vector<> ImpVec4 = {1.0, 2.0, 3.0, 4.0}; + + // CHECK: DeclStmt + // CHECK-NEXT: VarDecl {{.*}} ImpVec4 'vector<>':'vector<float, 4>' cinit + return 1; +} diff --git a/clang/test/CodeGenHLSL/inline/SpirvType.alignment.hlsl b/clang/test/CodeGenHLSL/inline/SpirvType.alignment.hlsl new file mode 100644 index 0000000000000..4cd8f2bf914aa --- /dev/null +++ b/clang/test/CodeGenHLSL/inline/SpirvType.alignment.hlsl @@ -0,0 +1,16 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \ +// RUN: -o - | FileCheck %s + +using Int = vk::SpirvType</* OpTypeInt */ 21, 4, 64, vk::Literal<vk::integral_constant<uint, 8>>, vk::Literal<vk::integral_constant<bool, false>>>; + +// CHECK: %struct.S = type <{ i32, [4 x i8], target("spirv.Type", target("spirv.Literal", 8), target("spirv.Literal", 0), 21, 4, 64), [8 x i8] }> +struct S { + int a; + Int b; +}; + +[numthreads(1,1,1)] +void main() { + S value; +} diff --git a/clang/test/CodeGenHLSL/inline/SpirvType.dx.error.hlsl b/clang/test/CodeGenHLSL/inline/SpirvType.dx.error.hlsl new file mode 100644 index 0000000000000..8c7140689ce74 --- /dev/null +++ b/clang/test/CodeGenHLSL/inline/SpirvType.dx.error.hlsl @@ -0,0 +1,12 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: dxil-pc-shadermodel6.0-compute %s \ +// RUN: -fsyntax-only -verify + +typedef vk::SpirvType<12, 2, 4, float> InvalidType1; // expected-error {{use of undeclared identifier 'vk'}} +vk::Literal<nullptr> Unused; // expected-error {{use of undeclared identifier 'vk'}} +vk::integral_constant<uint, 456> Unused2; // expected-error {{use of undeclared identifier 'vk'}} +typedef vk::SpirvOpaqueType<12, float> InvalidType2; // expected-error {{use of undeclared identifier 'vk'}} + 
+[numthreads(1, 1, 1)] +void main() { +} diff --git a/clang/test/CodeGenHLSL/inline/SpirvType.hlsl b/clang/test/CodeGenHLSL/inline/SpirvType.hlsl new file mode 100644 index 0000000000000..ea013c62899f8 --- /dev/null +++ b/clang/test/CodeGenHLSL/inline/SpirvType.hlsl @@ -0,0 +1,68 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \ +// RUN: -o - | FileCheck %s + +template<class T, uint64_t Size> +using Array = vk::SpirvOpaqueType</* OpTypeArray */ 28, T, vk::integral_constant<uint64_t, Size>>; + +template<uint64_t Size> +using ArrayBuffer = Array<RWBuffer<float>, Size>; + +typedef vk::SpirvType</* OpTypeInt */ 21, 4, 32, vk::Literal<vk::integral_constant<uint, 32>>, vk::Literal<vk::integral_constant<bool, false>>> Int; + +typedef Array<Int, 5> ArrayInt; + +// CHECK: %struct.S = type { target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0), target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) } +struct S { + ArrayBuffer<4> b; + Int i; +}; + +// CHECK: define spir_func target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0) @_Z14getArrayBufferu17spirv_type_28_0_0U5_TypeN4hlsl8RWBufferIfEEU6_ConstLm4E(target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0) %v) #0 +ArrayBuffer<4> getArrayBuffer(ArrayBuffer<4> v) { + return v; +} + +// CHECK: define spir_func target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) @_Z6getIntu18spirv_type_21_4_32U4_LitLi32EU4_LitLi0E(target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) %v) #0 +Int getInt(Int v) { + return v; +} + +// TODO: uncomment and test once CBuffer handles are implemented for SPIR-V +// ArrayBuffer<4> g_buffers; +// Int g_word; + 
+[numthreads(1, 1, 1)] +void main() { + // CHECK: %buffers = alloca target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0), align 4 + ArrayBuffer<4> buffers; + + // CHECK: %longBuffers = alloca target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 591751049, 1), 28, 0, 0), align 4 + ArrayBuffer<0x123456789> longBuffers; + + // CHECK: %word = alloca target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32), align 4 + Int word; + + // CHECK: %words = alloca [4 x target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32)], align 4 + Int words[4]; + + // CHECK: %words2 = alloca target("spirv.Type", target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32), target("spirv.IntegralConstant", i64, 5), 28, 0, 0), align 4 + ArrayInt words2; + + // CHECK: %value = alloca %struct.S, align 4 + S value; + + // CHECK: %buffers2 = alloca target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0), align 4 + // CHECK: %word2 = alloca target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32), align 4 + + + // CHECK: [[loaded:%[0-9]+]] = load target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0), ptr %buffers, align 4 + // CHECK: %call1 = call spir_func target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0) @_Z14getArrayBufferu17spirv_type_28_0_0U5_TypeN4hlsl8RWBufferIfEEU6_ConstLm4E(target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0) [[loaded]]) + // CHECK: store target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0) 
%call1, ptr %buffers2, align 4 + ArrayBuffer<4> buffers2 = getArrayBuffer(buffers); + + // CHECK: [[loaded:%[0-9]+]] = load target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32), ptr %word, align 4 + // CHECK: %call2 = call spir_func target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) @_Z6getIntu18spirv_type_21_4_32U4_LitLi32EU4_LitLi0E(target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) [[loaded]]) + // CHECK: store target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) %call2, ptr %word2, align 4 + Int word2 = getInt(word); +} diff --git a/clang/test/CodeGenHLSL/inline/SpirvType.incomplete.hlsl b/clang/test/CodeGenHLSL/inline/SpirvType.incomplete.hlsl new file mode 100644 index 0000000000000..9f4596e6974ee --- /dev/null +++ b/clang/test/CodeGenHLSL/inline/SpirvType.incomplete.hlsl @@ -0,0 +1,14 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: spirv-unknown-vulkan-compute %s -fsyntax-only -verify + +struct S; // expected-note {{forward declaration of 'S'}} + +// expected-error@hlsl/hlsl_spirv.h:26 {{argument type 'S' is incomplete}} + +typedef vk::SpirvOpaqueType</* OpTypeArray */ 28, S, vk::integral_constant<uint, 4>> ArrayOfS; // #1 +// expected-note@#1 {{in instantiation of template type alias 'SpirvOpaqueType' requested here}} + +[numthreads(1, 1, 1)] +void main() { + ArrayOfS buffers; +} diff --git a/clang/test/CodeGenHLSL/inline/SpirvType.literal.error.hlsl b/clang/test/CodeGenHLSL/inline/SpirvType.literal.error.hlsl new file mode 100644 index 0000000000000..44d7e855ba5cd --- /dev/null +++ b/clang/test/CodeGenHLSL/inline/SpirvType.literal.error.hlsl @@ -0,0 +1,11 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: spirv-unknown-vulkan-compute %s -fsyntax-only -verify + +// expected-error@hlsl/hlsl_spirv.h:20 {{the argument to vk::Literal must be a 
vk::integral_constant}} + +typedef vk::SpirvOpaqueType<28, vk::Literal<float>> T; // #1 +// expected-note@#1 {{in instantiation of template type alias 'SpirvOpaqueType' requested here}} + +[numthreads(1, 1, 1)] +void main() { +} diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp index 197ba2cd6856e..abbc6a7ccb6eb 100644 --- a/clang/tools/libclang/CIndex.cpp +++ b/clang/tools/libclang/CIndex.cpp @@ -1794,6 +1794,11 @@ bool CursorVisitor::VisitHLSLAttributedResourceTypeLoc( return Visit(TL.getWrappedLoc()); } +bool CursorVisitor::VisitHLSLInlineSpirvTypeLoc(HLSLInlineSpirvTypeLoc TL) { + // Nothing to do. + return false; +} + bool CursorVisitor::VisitFunctionTypeLoc(FunctionTypeLoc TL, bool SkipResultType) { if (!SkipResultType && Visit(TL.getReturnLoc())) diff --git a/clang/tools/libclang/CXType.cpp b/clang/tools/libclang/CXType.cpp index 2c9ef282b8abc..225790a5ffd80 100644 --- a/clang/tools/libclang/CXType.cpp +++ b/clang/tools/libclang/CXType.cpp @@ -636,6 +636,7 @@ CXString clang_getTypeKindSpelling(enum CXTypeKind K) { TKIND(Attributed); TKIND(BTFTagAttributed); TKIND(HLSLAttributedResource); + TKIND(HLSLInlineSpirv); TKIND(BFloat16); #define IMAGE_TYPE(ImgType, Id, SingletonId, Access, Suffix) TKIND(Id); #include "clang/Basic/OpenCLImageTypes.def" diff --git a/clang/utils/TableGen/ClangBuiltinTemplatesEmitter.cpp b/clang/utils/TableGen/ClangBuiltinTemplatesEmitter.cpp index 34bc782e007d5..797e6c3f4d04b 100644 --- a/clang/utils/TableGen/ClangBuiltinTemplatesEmitter.cpp +++ b/clang/utils/TableGen/ClangBuiltinTemplatesEmitter.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "TableGenBackends.h" +#include "llvm/ADT/StringSet.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/TableGenBackend.h" @@ -21,11 +22,14 @@ using namespace llvm; static std::string TemplateNameList; static std::string CreateBuiltinTemplateParameterList; +static llvm::StringSet BuiltinClasses; + 
namespace { struct ParserState { size_t UniqueCounter = 0; size_t CurrentDepth = 0; bool EmittedSizeTInfo = false; + bool EmittedUint32TInfo = false; }; std::pair<std::string, std::string> @@ -62,7 +66,7 @@ ParseTemplateParameterList(ParserState &PS, if (TemplateNameToParmName.find(Type.str()) == TemplateNameToParmName.end()) { - PrintFatalError("Unkown Type Name"); + PrintFatalError("Unknown Type Name"); } auto TSIName = "TSI" + std::to_string(PS.UniqueCounter++); @@ -75,19 +79,32 @@ ParseTemplateParameterList(ParserState &PS, << TSIName << "->getType(), " << Arg->getValueAsBit("IsVariadic") << ", " << TSIName << ");\n"; } else if (Arg->isSubClassOf("BuiltinNTTP")) { - if (Arg->getValueAsString("TypeName") != "size_t") - PrintFatalError("Unkown Type Name"); - if (!PS.EmittedSizeTInfo) { - Code << "TypeSourceInfo *SizeTInfo = " - "C.getTrivialTypeSourceInfo(C.getSizeType());\n"; - PS.EmittedSizeTInfo = true; + std::string SourceInfo; + if (Arg->getValueAsString("TypeName") == "size_t") { + SourceInfo = "SizeTInfo"; + if (!PS.EmittedSizeTInfo) { + Code << "TypeSourceInfo *SizeTInfo = " + "C.getTrivialTypeSourceInfo(C.getSizeType());\n"; + PS.EmittedSizeTInfo = true; + } + } else if (Arg->getValueAsString("TypeName") == "uint32_t") { + SourceInfo = "Uint32TInfo"; + if (!PS.EmittedUint32TInfo) { + Code << "TypeSourceInfo *Uint32TInfo = " + "C.getTrivialTypeSourceInfo(C.UnsignedIntTy);\n"; + PS.EmittedUint32TInfo = true; + } + } else { + PrintFatalError("Unknown Type Name"); } Code << " auto *" << ParmName << " = NonTypeTemplateParmDecl::Create(C, DC, SourceLocation(), " "SourceLocation(), " - << PS.CurrentDepth << ", " << Position++ - << ", /*Id=*/nullptr, SizeTInfo->getType(), " - "/*ParameterPack=*/false, SizeTInfo);\n"; + << PS.CurrentDepth << ", " << Position++ << ", /*Id=*/nullptr, " + << SourceInfo + << "->getType(), " + "/*ParameterPack=*/false, " + << SourceInfo << ");\n"; } else { PrintFatalError("Unknown Argument Type"); } @@ -134,7 +151,8 @@ 
EmitCreateBuiltinTemplateParameterList(std::vector<const Record *> TemplateArgs, CreateBuiltinTemplateParameterList += " }\n"; } -void EmitBuiltinTemplate(raw_ostream &OS, const Record *BuiltinTemplate) { +void EmitBuiltinTemplate(const Record *BuiltinTemplate) { + auto Class = BuiltinTemplate->getType()->getAsString(); auto Name = BuiltinTemplate->getName(); std::vector<const Record *> TemplateHead = @@ -142,21 +160,49 @@ void EmitBuiltinTemplate(raw_ostream &OS, const Record *BuiltinTemplate) { EmitCreateBuiltinTemplateParameterList(TemplateHead, Name); - TemplateNameList += "BuiltinTemplate("; + TemplateNameList += Class + "("; TemplateNameList += Name; TemplateNameList += ")\n"; + + BuiltinClasses.insert(Class); +} + +void EmitDefaultDefine(llvm::raw_ostream &OS, StringRef Name) { + OS << "#ifndef " << Name << "\n"; + OS << "#define " << Name << "(NAME)" << " " << "BuiltinTemplate" + << "(NAME)\n"; + OS << "#endif\n\n"; +} + +void EmitUndef(llvm::raw_ostream &OS, StringRef Name) { + OS << "#undef " << Name << "\n"; } } // namespace void clang::EmitClangBuiltinTemplates(const llvm::RecordKeeper &Records, llvm::raw_ostream &OS) { emitSourceFileHeader("Tables and code for Clang's builtin templates", OS); + for (const auto *Builtin : Records.getAllDerivedDefinitions("BuiltinTemplate")) - EmitBuiltinTemplate(OS, Builtin); + EmitBuiltinTemplate(Builtin); + + for (const auto &ClassEntry : BuiltinClasses) { + StringRef Class = ClassEntry.getKey(); + if (Class == "BuiltinTemplate") + continue; + EmitDefaultDefine(OS, Class); + } OS << "#if defined(CREATE_BUILTIN_TEMPLATE_PARAMETER_LIST)\n" << CreateBuiltinTemplateParameterList << "#undef CREATE_BUILTIN_TEMPLATE_PARAMETER_LIST\n#else\n" << TemplateNameList << "#undef BuiltinTemplate\n#endif\n"; + + for (const auto &ClassEntry : BuiltinClasses) { + StringRef Class = ClassEntry.getKey(); + if (Class == "BuiltinTemplate") + continue; + EmitUndef(OS, Class); + } } >From 7bdc924c57eb4b7e9e28143eff87086310e7b080 Mon Sep 17 
00:00:00 2001 From: Cassandra Beckley <cbeck...@google.com> Date: Tue, 1 Apr 2025 23:34:31 -0700 Subject: [PATCH 2/2] Fix formatting; remove unused code --- clang/lib/Headers/hlsl/hlsl_spirv.h | 30 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/clang/lib/Headers/hlsl/hlsl_spirv.h b/clang/lib/Headers/hlsl/hlsl_spirv.h index 8a71699a4ed5c..711da2fea46a4 100644 --- a/clang/lib/Headers/hlsl/hlsl_spirv.h +++ b/clang/lib/Headers/hlsl/hlsl_spirv.h @@ -10,21 +10,19 @@ #define _HLSL_HLSL_SPIRV_H_ namespace hlsl { - namespace vk { - // template <class T> using Foo = __hlsl_spirv_t; - // typedef Foo - template <typename T, T v> struct integral_constant { - static constexpr T value = v; - }; - - template <typename T> struct Literal {}; - - template <uint Opcode, uint Size, uint Alignment, typename... Operands> - using SpirvType = __hlsl_spirv_type<Opcode, Size, Alignment, Operands...>; - - template <uint Opcode, typename... Operands> - using SpirvOpaqueType = __hlsl_spirv_type<Opcode, 0, 0, Operands...>; - } // namespace vk - } // namespace hlsl +namespace vk { +template <typename T, T v> struct integral_constant { + static constexpr T value = v; +}; + +template <typename T> struct Literal {}; + +template <uint Opcode, uint Size, uint Alignment, typename... Operands> +using SpirvType = __hlsl_spirv_type<Opcode, Size, Alignment, Operands...>; + +template <uint Opcode, typename... Operands> +using SpirvOpaqueType = __hlsl_spirv_type<Opcode, 0, 0, Operands...>; +} // namespace vk +} // namespace hlsl #endif // _HLSL_HLSL_SPIRV_H_ \ No newline at end of file _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits