https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/120811
None >From 2b73435826c02577398fe96efc89d9efb4df0ef6 Mon Sep 17 00:00:00 2001 From: Finn Plummer <canadienf...@gmail.com> Date: Fri, 20 Dec 2024 01:22:06 +0000 Subject: [PATCH 1/4] [HLSL] Define in-memory structure of RootElements --- .../llvm/Frontend/HLSL/HLSLRootSignature.h | 258 ++++++++++++++++++ 1 file changed, 258 insertions(+) create mode 100644 llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h new file mode 100644 index 00000000000000..69d42b45f64b7e --- /dev/null +++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h @@ -0,0 +1,258 @@ +//===- HLSLRootSignature.h - HLSL Root Signature helper objects -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file This file contains helper objects for working with HLSL Root +/// Signatures. 
+/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H +#define LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H + +#include <stdint.h> + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Endian.h" + +namespace llvm { +namespace hlsl { +namespace root_signature { + +// This is a copy from DebugInfo/CodeView/CodeView.h +#define RS_DEFINE_ENUM_CLASS_FLAGS_OPERATORS(Class) \ + inline Class operator|(Class a, Class b) { \ + return static_cast<Class>(llvm::to_underlying(a) | \ + llvm::to_underlying(b)); \ + } \ + inline Class operator&(Class a, Class b) { \ + return static_cast<Class>(llvm::to_underlying(a) & \ + llvm::to_underlying(b)); \ + } \ + inline Class operator~(Class a) { \ + return static_cast<Class>(~llvm::to_underlying(a)); \ + } \ + inline Class &operator|=(Class &a, Class b) { \ + a = a | b; \ + return a; \ + } \ + inline Class &operator&=(Class &a, Class b) { \ + a = a & b; \ + return a; \ + } + +// Various enumerations and flags + +enum class RootFlags : uint32_t { + None = 0, + AllowInputAssemblerInputLayout = 0x1, + DenyVertexShaderRootAccess = 0x2, + DenyHullShaderRootAccess = 0x4, + DenyDomainShaderRootAccess = 0x8, + DenyGeometryShaderRootAccess = 0x10, + DenyPixelShaderRootAccess = 0x20, + AllowStreamOutput = 0x40, + LocalRootSignature = 0x80, + DenyAmplificationShaderRootAccess = 0x100, + DenyMeshShaderRootAccess = 0x200, + CBVSRVUAVHeapDirectlyIndexed = 0x400, + SamplerHeapDirectlyIndexed = 0x800, + AllowLowTierReservedHwCbLimit = 0x80000000, + ValidFlags = 0x80000fff +}; +RS_DEFINE_ENUM_CLASS_FLAGS_OPERATORS(RootFlags) + +enum class RootDescriptorFlags : unsigned { + None = 0, + DataVolatile = 0x2, + DataStaticWhileSetAtExecute = 0x4, + DataStatic = 0x8, + ValidFlags = 0xe +}; +RS_DEFINE_ENUM_CLASS_FLAGS_OPERATORS(RootDescriptorFlags) + +enum class DescriptorRangeFlags : unsigned { + None = 0, + DescriptorsVolatile = 0x1, + DataVolatile = 0x2, + 
DataStaticWhileSetAtExecute = 0x4, + DataStatic = 0x8, + DescriptorsStaticKeepingBufferBoundsChecks = 0x10000, + ValidFlags = 0x1000f, + ValidSamplerFlags = DescriptorsVolatile, +}; +RS_DEFINE_ENUM_CLASS_FLAGS_OPERATORS(DescriptorRangeFlags) + +enum class ShaderVisibility { + All = 0, + Vertex = 1, + Hull = 2, + Domain = 3, + Geometry = 4, + Pixel = 5, + Amplification = 6, + Mesh = 7, +}; + +enum class Filter { + MinMagMipPoint = 0, + MinMagPointMipLinear = 0x1, + MinPointMagLinearMipPoint = 0x4, + MinPointMagMipLinear = 0x5, + MinLinearMagMipPoint = 0x10, + MinLinearMagPointMipLinear = 0x11, + MinMagLinearMipPoint = 0x14, + MinMagMipLinear = 0x15, + Anisotropic = 0x55, + ComparisonMinMagMipPoint = 0x80, + ComparisonMinMagPointMipLinear = 0x81, + ComparisonMinPointMagLinearMipPoint = 0x84, + ComparisonMinPointMagMipLinear = 0x85, + ComparisonMinLinearMagMipPoint = 0x90, + ComparisonMinLinearMagPointMipLinear = 0x91, + ComparisonMinMagLinearMipPoint = 0x94, + ComparisonMinMagMipLinear = 0x95, + ComparisonAnisotropic = 0xd5, + MinimumMinMagMipPoint = 0x100, + MinimumMinMagPointMipLinear = 0x101, + MinimumMinPointMagLinearMipPoint = 0x104, + MinimumMinPointMagMipLinear = 0x105, + MinimumMinLinearMagMipPoint = 0x110, + MinimumMinLinearMagPointMipLinear = 0x111, + MinimumMinMagLinearMipPoint = 0x114, + MinimumMinMagMipLinear = 0x115, + MinimumAnisotropic = 0x155, + MaximumMinMagMipPoint = 0x180, + MaximumMinMagPointMipLinear = 0x181, + MaximumMinPointMagLinearMipPoint = 0x184, + MaximumMinPointMagMipLinear = 0x185, + MaximumMinLinearMagMipPoint = 0x190, + MaximumMinLinearMagPointMipLinear = 0x191, + MaximumMinMagLinearMipPoint = 0x194, + MaximumMinMagMipLinear = 0x195, + MaximumAnisotropic = 0x1d5 +}; + +enum class TextureAddressMode { + Wrap = 1, + Mirror = 2, + Clamp = 3, + Border = 4, + MirrorOnce = 5 +}; + +enum class ComparisonFunc : unsigned { + Never = 1, + Less = 2, + Equal = 3, + LessEqual = 4, + Greater = 5, + NotEqual = 6, + GreaterEqual = 7, + Always = 8 +}; 
+ +enum class StaticBorderColor { + TransparentBlack = 0, + OpaqueBlack = 1, + OpaqueWhite = 2, + OpaqueBlackUint = 3, + OpaqueWhiteUint = 4 +}; + +// Define the in-memory layout structures + +// Models the different registers: bReg | tReg | uReg | sReg +enum class RegisterType { BReg, TReg, UReg, SReg }; +struct Register { + RegisterType ViewType; + uint32_t Number; +}; + +// Models RootConstants | RootCBV | RootSRV | RootUAV collecting like +// parameters +enum class RootType { CBV, SRV, UAV, Constants }; +struct RootParameter { + RootType Type; + Register Register; + union { + uint32_t Num32BitConstants; + RootDescriptorFlags Flags = RootDescriptorFlags::None; + }; + uint32_t Space = 0; + ShaderVisibility Visibility = ShaderVisibility::All; +}; + +static const uint32_t DescriptorTableOffsetAppend = 0xffffffff; +// Models DTClause : CBV | SRV | UAV | Sampler collecting like parameters +enum class ClauseType { CBV, SRV, UAV, Sampler }; +struct DescriptorTableClause { + ClauseType Type; + Register Register; + uint32_t NumDescriptors = 1; + uint32_t Space = 0; + uint32_t Offset = DescriptorTableOffsetAppend; + DescriptorRangeFlags Flags = DescriptorRangeFlags::None; +}; + +// Models the start of a descriptor table +struct DescriptorTable { + ShaderVisibility Visibility = ShaderVisibility::All; + uint32_t NumClauses = 0; +}; + +struct StaticSampler { + Register Register; + Filter Filter = Filter::Anisotropic; + TextureAddressMode AddressU = TextureAddressMode::Wrap; + TextureAddressMode AddressV = TextureAddressMode::Wrap; + TextureAddressMode AddressW = TextureAddressMode::Wrap; + float MipLODBias = 0.f; + uint32_t MaxAnisotropy = 16; + ComparisonFunc ComparisonFunc = ComparisonFunc::LessEqual; + StaticBorderColor BorderColor = StaticBorderColor::OpaqueWhite; + float MinLOD = 0.f; + float MaxLODBias = 3.402823466e+38f; + uint32_t Space = 0; + ShaderVisibility Visibility = ShaderVisibility::All; +}; + +struct RootElement { + enum class ElementType { + RootFlags, + 
RootParameter, + DescriptorTable, + DescriptorTableClause, + StaticSampler + }; + + ElementType Tag; + union { + RootFlags Flags; + RootParameter Parameter; + DescriptorTable Table; + DescriptorTableClause Clause; + StaticSampler Sampler; + }; + + // Constructors + RootElement(RootFlags Flags) : Tag(ElementType::RootFlags), Flags(Flags) {} + RootElement(RootParameter Parameter) + : Tag(ElementType::RootParameter), Parameter(Parameter) {} + RootElement(DescriptorTable Table) + : Tag(ElementType::DescriptorTable), Table(Table) {} + RootElement(DescriptorTableClause Clause) + : Tag(ElementType::DescriptorTableClause), Clause(Clause) {} + RootElement(StaticSampler Sampler) + : Tag(ElementType::StaticSampler), Sampler(Sampler) {} +}; + +} // namespace root_signature +} // namespace hlsl +} // namespace llvm + +#endif // LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H >From e5a9e84168a69ccc0e5b02fa839b6cb69f3fb417 Mon Sep 17 00:00:00 2001 From: Finn Plummer <canadienf...@gmail.com> Date: Fri, 20 Dec 2024 01:22:26 +0000 Subject: [PATCH 2/4] [HLSL] Implement parsing of `RootFlags` --- .../clang/Sema/ParseHLSLRootSignature.h | 55 +++++++ clang/lib/Sema/CMakeLists.txt | 1 + clang/lib/Sema/ParseHLSLRootSignature.cpp | 149 ++++++++++++++++++ clang/unittests/Sema/CMakeLists.txt | 1 + .../Sema/ParseHLSLRootSignatureTest.cpp | 58 +++++++ 5 files changed, 264 insertions(+) create mode 100644 clang/include/clang/Sema/ParseHLSLRootSignature.h create mode 100644 clang/lib/Sema/ParseHLSLRootSignature.cpp create mode 100644 clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp diff --git a/clang/include/clang/Sema/ParseHLSLRootSignature.h b/clang/include/clang/Sema/ParseHLSLRootSignature.h new file mode 100644 index 00000000000000..3130c71f78d5be --- /dev/null +++ b/clang/include/clang/Sema/ParseHLSLRootSignature.h @@ -0,0 +1,55 @@ +//===--- ParseHLSLRootSignature.h -------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the ParseHLSLRootSignature interface. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H +#define LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H + +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" + +#include "llvm/Frontend/HLSL/HLSLRootSignature.h" + +namespace llvm { +namespace hlsl { +namespace root_signature { + +class Parser { +public: + Parser(StringRef Signature, SmallVector<RootElement> *Elements) + : Buffer(Signature), Elements(Elements) {} + + bool Parse(); + +private: + bool ReportError(); + + // RootElements parse methods + bool ParseRootElement(); + bool ParseRootFlags(); + + // Enum methods + template <typename EnumType> + bool ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping, + EnumType &Enum); + bool ParseRootFlag(RootFlags &Flag); + + StringRef Buffer; + SmallVector<RootElement> *Elements; + + StringRef Token; +}; + +} // namespace root_signature +} // namespace hlsl +} // namespace llvm + +#endif // LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H diff --git a/clang/lib/Sema/CMakeLists.txt b/clang/lib/Sema/CMakeLists.txt index 719c3a9312ec15..7141bb42eb4363 100644 --- a/clang/lib/Sema/CMakeLists.txt +++ b/clang/lib/Sema/CMakeLists.txt @@ -24,6 +24,7 @@ add_clang_library(clangSema JumpDiagnostics.cpp MultiplexExternalSemaSource.cpp ParsedAttr.cpp + ParseHLSLRootSignature.cpp Scope.cpp ScopeInfo.cpp Sema.cpp diff --git a/clang/lib/Sema/ParseHLSLRootSignature.cpp b/clang/lib/Sema/ParseHLSLRootSignature.cpp new file mode 100644 index 00000000000000..f86433dd4d903b --- /dev/null +++ b/clang/lib/Sema/ParseHLSLRootSignature.cpp @@ -0,0 +1,149 @@ +#include "clang/Sema/ParseHLSLRootSignature.h" + +namespace llvm {
+namespace hlsl { +namespace root_signature { + +// TODO: Hook up with Sema to properly report semantic/validation errors +bool Parser::ReportError() { return true; } + +bool Parser::ParseRootFlags() { + // Set to RootFlags::None and skip whitespace to catch when we have RootFlags( + // ) + RootFlags Flags = RootFlags::None; + Buffer = Buffer.ltrim(); + bool First = true; + + // Loop until we reach the end of the rootflags + while (!Buffer.starts_with(")")) { + // Trim expected | when more than 1 flag + if (!First && !Buffer.consume_front("|")) + return ReportError(); + First = false; + + // Remove any whitespace + Buffer = Buffer.ltrim(); + + RootFlags CurFlag; + if (ParseRootFlag(CurFlag)) + return ReportError(); + Flags |= CurFlag; + + // Remove any whitespace + Buffer = Buffer.ltrim(); + } + + // Create and push the root element on the parsed elements + Elements->push_back(RootElement(Flags)); + return false; +} + +template <typename EnumType> +bool Parser::ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping, + EnumType &Enum) { + auto IsEnumChar = [](unsigned char C) { return isalnum(C) || C == '_'; }; + Token = Buffer.take_while(IsEnumChar); // Retrieve enum + Buffer = Buffer.drop_front(Token.size()); + + // Try to get the case-insensitive enum + auto Switch = llvm::StringSwitch<std::optional<EnumType>>(Token); + for (const auto &Pair : Mapping) + Switch.CaseLower(Pair.first, Pair.second); + auto MaybeEnum = Switch.Default(std::nullopt); + if (!MaybeEnum) + return true; + Enum = *MaybeEnum; + + return false; +} + +bool Parser::ParseRootFlag(RootFlags &Flag) { + SmallVector<std::pair<StringLiteral, RootFlags>> Mapping = { + {"0", RootFlags::None}, + {"ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT", + RootFlags::AllowInputAssemblerInputLayout}, + {"DENY_VERTEX_SHADER_ROOT_ACCESS", RootFlags::DenyVertexShaderRootAccess}, + {"DENY_HULL_SHADER_ROOT_ACCESS", RootFlags::DenyHullShaderRootAccess}, + {"DENY_DOMAIN_SHADER_ROOT_ACCESS", RootFlags::DenyDomainShaderRootAccess},
+ {"DENY_GEOMETRY_SHADER_ROOT_ACCESS", + RootFlags::DenyGeometryShaderRootAccess}, + {"DENY_PIXEL_SHADER_ROOT_ACCESS", RootFlags::DenyPixelShaderRootAccess}, + {"ALLOW_STREAM_OUTPUT", RootFlags::AllowStreamOutput}, + {"LOCAL_ROOT_SIGNATURE", RootFlags::LocalRootSignature}, + {"DENY_AMPLIFICATION_SHADER_ROOT_ACCESS", + RootFlags::DenyAmplificationShaderRootAccess}, + {"DENY_MESH_SHADER_ROOT_ACCESS", RootFlags::DenyMeshShaderRootAccess}, + {"CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED", + RootFlags::CBVSRVUAVHeapDirectlyIndexed}, + {"SAMPLER_HEAP_DIRECTLY_INDEXED", RootFlags::SamplerHeapDirectlyIndexed}, + {"AllowLowTierReservedHwCbLimit", + RootFlags::AllowLowTierReservedHwCbLimit}, + }; + + return ParseEnum<RootFlags>(Mapping, Flag); +} + +bool Parser::ParseRootElement() { + // Define different ParserMethods to use StringSwitch for dispatch + enum class ParserMethod { + ReportError, + ParseRootFlags, + }; + + // Retreive which method should be used + auto Method = llvm::StringSwitch<ParserMethod>(Token) + .Case("RootFlags", ParserMethod::ParseRootFlags) + .Default(ParserMethod::ReportError); + + // Dispatch on the correct method + bool Error = false; + switch (Method) { + case ParserMethod::ReportError: + Error = true; + break; + case ParserMethod::ParseRootFlags: + Error = ParseRootFlags(); + break; + case ParserMethod::ParseRootParameter: + Error = ParseRootParameter(); + break; + } + + if (Error) + return ReportError(); + + return false; +} + +bool Parser::Parse() { + bool First = true; + while (!Buffer.empty()) { + // Trim expected comma when more than 1 root element + if (!First && !Buffer.consume_front(",")) + return ReportError(); + First = false; + + // Remove any whitespace + Buffer = Buffer.ltrim(); + + // Retrieve the root element identifier, ignoring whitespace before '(' + auto Split = Buffer.split('('); + Token = Split.first.rtrim(); + Buffer = Split.second; + + // Dispatch to the applicable root element parser + if (ParseRootElement()) + return ReportError(); + + // Then we can clean up
the remaining ")" + if (!Buffer.consume_front(")")) + return ReportError(); + } + + // All input has been correctly parsed + return false; +} + +} // namespace root_signature +} // namespace hlsl +} // namespace llvm diff --git a/clang/unittests/Sema/CMakeLists.txt b/clang/unittests/Sema/CMakeLists.txt index 7ded562e8edfa5..f382f8b1235306 100644 --- a/clang/unittests/Sema/CMakeLists.txt +++ b/clang/unittests/Sema/CMakeLists.txt @@ -7,6 +7,7 @@ add_clang_unittest(SemaTests ExternalSemaSourceTest.cpp CodeCompleteTest.cpp GslOwnerPointerInference.cpp + ParseHLSLRootSignatureTest.cpp SemaLookupTest.cpp SemaNoloadLookupTest.cpp ) diff --git a/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp new file mode 100644 index 00000000000000..0e7feb50871669 --- /dev/null +++ b/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp @@ -0,0 +1,58 @@ +//=== ParseHLSLRootSignatureTest.cpp - Parse Root Signature tests ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/Sema/ParseHLSLRootSignature.h" +#include "gtest/gtest.h" + +using namespace llvm::hlsl::root_signature; + +namespace { + +TEST(ParseHLSLRootSignature, EmptyRootFlags) { + llvm::StringRef RootFlagString = " RootFlags()"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(RootFlagString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(1u, RootElements.size()); + ASSERT_EQ(RootFlags::None, RootElements[0].Flags); +} + +TEST(ParseHLSLRootSignature, RootFlagsNone) { + llvm::StringRef RootFlagString = " RootFlags(0)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(RootFlagString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(1u, RootElements.size()); + ASSERT_EQ(RootFlags::None, RootElements[0].Flags); +} + +TEST(ParseHLSLRootSignature, ValidRootFlags) { + // Test that the flags are all captured and that they are case insensitive + llvm::StringRef RootFlagString = " RootFlags( " + " ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT" + "| deny_vertex_shader_root_access" + "| DENY_HULL_SHADER_ROOT_ACCESS" + "| deny_domain_shader_root_access" + "| DENY_GEOMETRY_SHADER_ROOT_ACCESS" + "| deny_pixel_shader_root_access" + "| ALLOW_STREAM_OUTPUT" + "| LOCAL_ROOT_SIGNATURE" + "| deny_amplification_shader_root_access" + "| DENY_MESH_SHADER_ROOT_ACCESS" + "| cbv_srv_uav_heap_directly_indexed" + "| SAMPLER_HEAP_DIRECTLY_INDEXED" + "| AllowLowTierReservedHwCbLimit )"; + + llvm::SmallVector<RootElement> RootElements; + Parser Parser(RootFlagString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(1u, RootElements.size()); + ASSERT_EQ(RootFlags::ValidFlags, RootElements[0].Flags); +} + +} // anonymous namespace >From 7ee7154cdc3959822d184e04127176e5f81fc436 Mon Sep 17 00:00:00 2001 From: Finn Plummer <canadienf...@gmail.com> Date:
Fri, 20 Dec 2024 20:15:36 +0000 Subject: [PATCH 3/4] [HLSL] Implement parsing of `RootParameters` --- .../clang/Sema/ParseHLSLRootSignature.h | 11 ++ clang/lib/Sema/ParseHLSLRootSignature.cpp | 174 ++++++++++++++++++ .../Sema/ParseHLSLRootSignatureTest.cpp | 98 ++++++++++ 3 files changed, 283 insertions(+) diff --git a/clang/include/clang/Sema/ParseHLSLRootSignature.h b/clang/include/clang/Sema/ParseHLSLRootSignature.h index 3130c71f78d5be..6e5ef0855249ed 100644 --- a/clang/include/clang/Sema/ParseHLSLRootSignature.h +++ b/clang/include/clang/Sema/ParseHLSLRootSignature.h @@ -13,6 +13,7 @@ #ifndef LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H #define LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H +#include "llvm/ADT/APInt.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" @@ -35,12 +36,22 @@ class Parser { // RootElements parse methods bool ParseRootElement(); bool ParseRootFlags(); + bool ParseRootParameter(); + + // Helper methods + bool ParseAssign(); + bool ParseComma(); + bool ParseOptComma(); + bool ParseRegister(Register &); + bool ParseUnsignedInt(uint32_t &Number); // Enum methods template <typename EnumType> bool ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping, EnumType &Enum); + bool ParseRootDescriptorFlag(RootDescriptorFlags &Flag); bool ParseRootFlag(RootFlags &Flag); + bool ParseVisibility(ShaderVisibility &Visibility); StringRef Buffer; SmallVector<RootElement> *Elements; diff --git a/clang/lib/Sema/ParseHLSLRootSignature.cpp b/clang/lib/Sema/ParseHLSLRootSignature.cpp index f86433dd4d903b..d3a8db32dc5354 100644 --- a/clang/lib/Sema/ParseHLSLRootSignature.cpp +++ b/clang/lib/Sema/ParseHLSLRootSignature.cpp @@ -38,6 +38,147 @@ bool Parser::ParseRootFlags() { return false; } +bool Parser::ParseRootParameter() { + RootParameter Parameter; + Parameter.Type = llvm::StringSwitch<RootType>(Token) + .Case("CBV", RootType::CBV) + .Case("SRV", RootType::SRV) + .Case("UAV", RootType::UAV) + .Case("RootConstants", RootType::Constants);
+ // No .Default() needed: ParseRootElement only dispatches these four tokens + // Remove any whitespace + Buffer = Buffer.ltrim(); + + // Retrieve mandatory num32BitConstant arg for RootConstants + if (Parameter.Type == RootType::Constants) { + if (!Buffer.consume_front("num32BitConstants")) + return ReportError(); + + if (ParseAssign()) + return ReportError(); + + if (ParseUnsignedInt(Parameter.Num32BitConstants)) + return ReportError(); + + if (ParseOptComma()) + return ReportError(); + } + + // Retrieve mandatory register + if (ParseRegister(Parameter.Register)) + return ReportError(); + + if (ParseOptComma()) + return ReportError(); + + // Parse common optional space arg + if (Buffer.consume_front("space")) { + if (ParseAssign()) + return ReportError(); + + if (ParseUnsignedInt(Parameter.Space)) + return ReportError(); + + if (ParseOptComma()) + return ReportError(); + } + + // Parse common optional visibility arg + if (Buffer.consume_front("visibility")) { + if (ParseAssign()) + return ReportError(); + + if (ParseVisibility(Parameter.Visibility)) + return ReportError(); + + if (ParseOptComma()) + return ReportError(); + } + + // Retrieve optional flags arg for non-RootConstants + if (Parameter.Type != RootType::Constants && Buffer.consume_front("flags")) { + if (ParseAssign()) + return ReportError(); + + if (ParseRootDescriptorFlag(Parameter.Flags)) + return ReportError(); + + // Remove trailing whitespace + Buffer = Buffer.ltrim(); + } + + // Create and push the root element on the parsed elements + Elements->push_back(RootElement(Parameter)); + return false; +} + +// Helper Parser methods + +// Parses " = " with varying whitespace +bool Parser::ParseAssign() { + Buffer = Buffer.ltrim(); + if (!Buffer.starts_with('=')) + return true; + Buffer = Buffer.drop_front(); + Buffer = Buffer.ltrim(); + return false; +} + +// Parses ", " with varying whitespace +bool Parser::ParseComma() { + if (!Buffer.starts_with(',')) + return true;
+ Buffer = Buffer.drop_front(); + Buffer = Buffer.ltrim(); + return false; +} + +// Parses ", " if possible. When successful we expect another parameter, and +// return no error, otherwise we expect that we should be at the end of the +// root element and return an error if this isn't the case +bool Parser::ParseOptComma() { + if (!ParseComma()) + return false; + Buffer = Buffer.ltrim(); + return !Buffer.starts_with(')'); +} + +bool Parser::ParseRegister(Register &Register) { + // Parse expected register type ('b', 't', 'u', 's') + if (Buffer.empty()) + return ReportError(); + + // Get type character + Token = Buffer.take_front(); + Buffer = Buffer.drop_front(); + + auto MaybeType = llvm::StringSwitch<std::optional<RegisterType>>(Token) + .Case("b", RegisterType::BReg) + .Case("t", RegisterType::TReg) + .Case("u", RegisterType::UReg) + .Case("s", RegisterType::SReg) + .Default(std::nullopt); + if (!MaybeType) + return ReportError(); + Register.ViewType = *MaybeType; + + if (ParseUnsignedInt(Register.Number)) + return ReportError(); + + return false; +} + +// Parses "[0-9]+" as a uint32_t, rejecting values that do not fit +bool Parser::ParseUnsignedInt(uint32_t &Number) { + auto IsDigit = [](unsigned char C) { return isdigit(C) != 0; }; + StringRef NumString = Buffer.take_while(IsDigit); + // Fails (returns true) on an empty string or a value that overflows uint32_t + if (NumString.getAsInteger(/*radix=*/10, Number)) + return true; + Buffer = Buffer.drop_front(NumString.size()); + return false; +} + template <typename EnumType> bool Parser::ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping, EnumType &Enum) { @@ -57,6 +198,18 @@ bool Parser::ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping, return false; } +bool Parser::ParseRootDescriptorFlag(RootDescriptorFlags &Flag) { + SmallVector<std::pair<StringLiteral, RootDescriptorFlags>> Mapping = { + {"0", RootDescriptorFlags::None}, + {"DATA_VOLATILE", RootDescriptorFlags::DataVolatile}, + 
{"DATA_STATIC_WHILE_SET_AT_EXECUTE", + RootDescriptorFlags::DataStaticWhileSetAtExecute}, +
{"DATA_STATIC", RootDescriptorFlags::DataStatic}, + }; + + return ParseEnum<RootDescriptorFlags>(Mapping, Flag); +} + bool Parser::ParseRootFlag(RootFlags &Flag) { SmallVector<std::pair<StringLiteral, RootFlags>> Mapping = { {"0", RootFlags::None}, @@ -83,16 +236,36 @@ bool Parser::ParseRootFlag(RootFlags &Flag) { return ParseEnum<RootFlags>(Mapping, Flag); } +bool Parser::ParseVisibility(ShaderVisibility &Visibility) { + SmallVector<std::pair<StringLiteral, ShaderVisibility>> Mapping = { + {"SHADER_VISIBILITY_ALL", ShaderVisibility::All}, + {"SHADER_VISIBILITY_VERTEX", ShaderVisibility::Vertex}, + {"SHADER_VISIBILITY_HULL", ShaderVisibility::Hull}, + {"SHADER_VISIBILITY_DOMAIN", ShaderVisibility::Domain}, + {"SHADER_VISIBILITY_GEOMETRY", ShaderVisibility::Geometry}, + {"SHADER_VISIBILITY_PIXEL", ShaderVisibility::Pixel}, + {"SHADER_VISIBILITY_AMPLIFICATION", ShaderVisibility::Amplification}, + {"SHADER_VISIBILITY_MESH", ShaderVisibility::Mesh}, + }; + + return ParseEnum<ShaderVisibility>(Mapping, Visibility); +} + bool Parser::ParseRootElement() { // Define different ParserMethods to use StringSwitch for dispatch enum class ParserMethod { ReportError, ParseRootFlags, + ParseRootParameter, }; // Retreive which method should be used auto Method = llvm::StringSwitch<ParserMethod>(Token) .Case("RootFlags", ParserMethod::ParseRootFlags) + .Case("RootConstants", ParserMethod::ParseRootParameter) + .Case("CBV", ParserMethod::ParseRootParameter) + .Case("SRV", ParserMethod::ParseRootParameter) + .Case("UAV", ParserMethod::ParseRootParameter) .Default(ParserMethod::ReportError); // Dispatch on the correct method @@ -115,6 +288,7 @@ bool Parser::ParseRootElement() { return false; } +// Parser entry point function bool Parser::Parse() { bool First = true; while (!Buffer.empty()) { diff --git a/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp index 0e7feb50871669..eefe20293732ad 100644 ---
a/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp +++ b/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp @@ -55,4 +55,102 @@ TEST(ParseHLSLRootSignature, ValidRootFlags) { ASSERT_EQ(RootFlags::ValidFlags, RootElements[0].Flags); } +TEST(ParseHLSLRootSignature, MandatoryRootConstant) { + llvm::StringRef RootFlagString = "RootConstants(num32BitConstants = 4, b42)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(RootFlagString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(1u, RootElements.size()); + + RootParameter Parameter = RootElements[0].Parameter; + ASSERT_EQ(RootType::Constants, Parameter.Type); + ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType); + ASSERT_EQ(42u, Parameter.Register.Number); + ASSERT_EQ(4u, Parameter.Num32BitConstants); + ASSERT_EQ(0u, Parameter.Space); + ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility); +} + +TEST(ParseHLSLRootSignature, OptionalRootConstant) { + llvm::StringRef RootFlagString = + "RootConstants(num32BitConstants = 4, b42, space = 4, visibility = " + "SHADER_VISIBILITY_DOMAIN)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(RootFlagString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(1u, RootElements.size()); + + RootParameter Parameter = RootElements[0].Parameter; + ASSERT_EQ(RootType::Constants, Parameter.Type); + ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType); + ASSERT_EQ(42u, Parameter.Register.Number); + ASSERT_EQ(4u, Parameter.Num32BitConstants); + ASSERT_EQ(4u, Parameter.Space); + ASSERT_EQ(ShaderVisibility::Domain, Parameter.Visibility); +} + +TEST(ParseHLSLRootSignature, DefaultRootCBV) { + llvm::StringRef ViewsString = "CBV(b0)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(ViewsString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(1u, RootElements.size()); + + RootParameter Parameter =
RootElements[0].Parameter; + ASSERT_EQ(RootType::CBV, Parameter.Type); + ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType); + ASSERT_EQ(0u, Parameter.Register.Number); + ASSERT_EQ(RootDescriptorFlags::None, Parameter.Flags); + ASSERT_EQ(0u, Parameter.Space); + ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility); +} + +TEST(ParseHLSLRootSignature, SampleRootCBV) { + llvm::StringRef ViewsString = "CBV(b982374, space = 1, flags = DATA_STATIC)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(ViewsString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(1u, RootElements.size()); + + RootParameter Parameter = RootElements[0].Parameter; + ASSERT_EQ(RootType::CBV, Parameter.Type); + ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType); + ASSERT_EQ(982374u, Parameter.Register.Number); + ASSERT_EQ(RootDescriptorFlags::DataStatic, Parameter.Flags); + ASSERT_EQ(1u, Parameter.Space); + ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility); +} + +TEST(ParseHLSLRootSignature, SampleRootSRV) { + llvm::StringRef ViewsString = "SRV(t3, visibility = SHADER_VISIBILITY_MESH, " + "flags = Data_Static_While_Set_At_Execute)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(ViewsString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(1u, RootElements.size()); + + RootParameter Parameter = RootElements[0].Parameter; + ASSERT_EQ(RootType::SRV, Parameter.Type); + ASSERT_EQ(RegisterType::TReg, Parameter.Register.ViewType); + ASSERT_EQ(3u, Parameter.Register.Number); + ASSERT_EQ(RootDescriptorFlags::DataStaticWhileSetAtExecute, Parameter.Flags); + ASSERT_EQ(0u, Parameter.Space); + ASSERT_EQ(ShaderVisibility::Mesh, Parameter.Visibility); +} + +TEST(ParseHLSLRootSignature, SampleRootUAV) { + llvm::StringRef ViewsString = "UAV(u0, flags = DATA_VOLATILE)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(ViewsString, &RootElements);
+ ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(1u, RootElements.size()); + + RootParameter Parameter = RootElements[0].Parameter; + ASSERT_EQ(RootType::UAV, Parameter.Type); + ASSERT_EQ(RegisterType::UReg, Parameter.Register.ViewType); + ASSERT_EQ(0u, Parameter.Register.Number); + ASSERT_EQ(RootDescriptorFlags::DataVolatile, Parameter.Flags); + ASSERT_EQ(0u, Parameter.Space); + ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility); +} } // anonymous namespace >From 75f9b4d8deaa5bc4f9b8c07ae79f8539f43c0304 Mon Sep 17 00:00:00 2001 From: Finn Plummer <canadienf...@gmail.com> Date: Sat, 21 Dec 2024 00:16:06 +0000 Subject: [PATCH 4/4] [HLSL] Implement parsing of `DescriptorTable` --- .../clang/Sema/ParseHLSLRootSignature.h | 4 + clang/lib/Sema/ParseHLSLRootSignature.cpp | 128 ++++++++++++++++++ .../llvm/Frontend/HLSL/HLSLRootSignature.h | 19 ++- 3 files changed, 150 insertions(+), 1 deletion(-) diff --git a/clang/include/clang/Sema/ParseHLSLRootSignature.h b/clang/include/clang/Sema/ParseHLSLRootSignature.h index 6e5ef0855249ed..2b0b410d4df52d 100644 --- a/clang/include/clang/Sema/ParseHLSLRootSignature.h +++ b/clang/include/clang/Sema/ParseHLSLRootSignature.h @@ -35,6 +35,9 @@ class Parser { // RootElements parse methods bool ParseRootElement(); + + bool ParseDescriptorTable(); + bool ParseDescriptorTableClause(); bool ParseRootFlags(); bool ParseRootParameter(); @@ -49,6 +52,7 @@ class Parser { template <typename EnumType> bool ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping, EnumType &Enum); + bool ParseDescriptorRangeFlag(DescriptorRangeFlags &Flag); bool ParseRootDescriptorFlag(RootDescriptorFlags &Flag); bool ParseRootFlag(RootFlags &Flag); bool ParseVisibility(ShaderVisibility &Visibility); diff --git a/clang/lib/Sema/ParseHLSLRootSignature.cpp b/clang/lib/Sema/ParseHLSLRootSignature.cpp index d3a8db32dc5354..36f5873e11853d 100644 --- a/clang/lib/Sema/ParseHLSLRootSignature.cpp +++
b/clang/lib/Sema/ParseHLSLRootSignature.cpp @@ -7,6 +7,119 @@ namespace root_signature { // TODO: Hook up with Sema to properly report semantic/validation errors bool Parser::ReportError() { return true; } + +bool Parser::ParseDescriptorTable() { + // Init table which will be updated as we add clauses + Elements->push_back(RootElement(DescriptorTable())); + DescriptorTable &Table = Elements->back().Table; + + bool First = true; + while (!Buffer.empty()) { + // Trim expected comma when more than 1 clause element + if (!First && !Buffer.consume_front(",")) + return ReportError(); + First = false; + + // Remove any whitespace + Buffer = Buffer.drop_while(isspace); + + // Retrieve the clause identifier (the text preceding the first '(') + auto Split = Buffer.split('('); + Token = Split.first; + Buffer = Split.second; + + // Dispatch to the applicable clause parser + if (ParseDescriptorTableClause()) + return ReportError(); + + // Then we can clean up the remaining ")" + if (!Buffer.consume_front(")")) + return ReportError(); + + Table.NumClauses++; + } + + if (ParseOptComma()) + return ReportError(); + + if (ParseVisibility(Table.Visibility)) + return ReportError(); + + // All input has been correctly parsed + return false; +} + +bool Parser::ParseDescriptorTableClause() { + auto MaybeType = llvm::StringSwitch<std::optional<ClauseType>>(Token) + .Case("CBV", ClauseType::CBV) + .Case("SRV", ClauseType::SRV) + .Case("UAV", ClauseType::UAV) + .Case("Sampler", ClauseType::Sampler) + .Default(std::nullopt); + if (!MaybeType) + return ReportError(); + DescriptorTableClause Clause(*MaybeType); + + // Retrieve mandatory register + if (ParseRegister(Clause.Register)) + return ReportError(); + + if (ParseOptComma()) + return ReportError(); + + // Parse optional numDescriptors arg + if (Buffer.consume_front("numDescriptors")) { + if (ParseAssign()) + return ReportError(); + + if (ParseUnsignedInt(Clause.NumDescriptors)) + return ReportError(); + + if (ParseOptComma()) + return ReportError(); + } + + // Parse
optional space arg + if (Buffer.consume_front("space")) { + if (ParseAssign()) + return ReportError(); + + if (ParseUnsignedInt(Clause.Space)) + return ReportError(); + + if (ParseOptComma()) + return ReportError(); + } + + // Parse optional offset arg + if (Buffer.consume_front("offset")) { + if (ParseAssign()) + return ReportError(); + + // Either a number or the literal DESCRIPTOR_RANGE_OFFSET_APPEND + if (Buffer.consume_front_insensitive("DESCRIPTOR_RANGE_OFFSET_APPEND")) + Clause.Offset = DescriptorTableOffsetAppend; + else if (ParseUnsignedInt(Clause.Offset)) + return ReportError(); + + if (ParseOptComma()) + return ReportError(); + } + + // Parse optional flags arg + if (Buffer.consume_front("flags")) { + if (ParseAssign()) + return ReportError(); + + if (ParseDescriptorRangeFlag(Clause.Flags)) + return ReportError(); + if (ParseOptComma()) + return ReportError(); + } + + return false; // NOTE(review): parsed Clause is not stored anywhere +} + bool Parser::ParseRootFlags() { // Set to RootFlags::None and skip whitespace to catch when we have RootFlags( // ) @@ -198,6 +311,21 @@ bool Parser::ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping, return false; } +bool Parser::ParseDescriptorRangeFlag(DescriptorRangeFlags &Flag) { + SmallVector<std::pair<StringLiteral, DescriptorRangeFlags>> Mapping = { + {"0", DescriptorRangeFlags::None}, + {"DESCRIPTORS_VOLATILE", DescriptorRangeFlags::DescriptorsVolatile}, + {"DATA_VOLATILE", DescriptorRangeFlags::DataVolatile}, + {"DATA_STATIC_WHILE_SET_AT_EXECUTE", + DescriptorRangeFlags::DataStaticWhileSetAtExecute}, + {"DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS", + DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks}, + {"DATA_STATIC", DescriptorRangeFlags::DataStatic}, + }; + + return ParseEnum<DescriptorRangeFlags>(Mapping, Flag); +} + bool Parser::ParseRootDescriptorFlag(RootDescriptorFlags &Flag) { SmallVector<std::pair<StringLiteral, RootDescriptorFlags>> Mapping = { {"0", RootDescriptorFlags::None}, diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h index 69d42b45f64b7e..77076557919dec 100644 --- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h +++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h @@ -196,7 +196,24 @@ struct DescriptorTableClause { uint32_t NumDescriptors = 1; uint32_t Space = 0; uint32_t Offset = DescriptorTableOffsetAppend; - DescriptorRangeFlags Flags = DescriptorRangeFlags::None; + DescriptorRangeFlags Flags; + + DescriptorTableClause(ClauseType Type) : Type(Type) { + switch (Type) { + case ClauseType::CBV: + Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute; + break; + case ClauseType::SRV: + Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute; + break; + case ClauseType::UAV: + Flags = DescriptorRangeFlags::DataVolatile; + break; + case ClauseType::Sampler: + Flags = DescriptorRangeFlags::None; + break; + } + } }; // Models the start of a descriptor table _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits