https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/121803
>From bee90e659e647df21d1f1e65f95fffefa57eadce Mon Sep 17 00:00:00 2001 From: Finn Plummer <canadienf...@gmail.com> Date: Mon, 6 Jan 2025 16:21:33 +0000 Subject: [PATCH 1/2] [HLSL] Implement parsing of `RootFlags` - Define the Parser class that will contain all the parsing methods in ParseHLSLRootSignature.h - Implement the dispatch behaviour of Parse and ParseRootElement in ParseHLSLRootSignature.cpp - Define the general in-memory data structure of a RootElement that will be a union of the various RootElement types - Implement the ParseRootFlags methods in ParseHLSLRootSignature - Define the in-memory representation of the RootFlag and add it to the RootElement structure - Add testing of valid inputs to ParseHLSLRootSignatureTest.cpp --- .../clang/Sema/ParseHLSLRootSignature.h | 63 ++++++++ clang/lib/Sema/CMakeLists.txt | 1 + clang/lib/Sema/ParseHLSLRootSignature.cpp | 139 ++++++++++++++++++ clang/unittests/Sema/CMakeLists.txt | 1 + .../Sema/ParseHLSLRootSignatureTest.cpp | 58 ++++++++ .../llvm/Frontend/HLSL/HLSLRootSignature.h | 88 +++++++++++ 6 files changed, 350 insertions(+) create mode 100644 clang/include/clang/Sema/ParseHLSLRootSignature.h create mode 100644 clang/lib/Sema/ParseHLSLRootSignature.cpp create mode 100644 clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp create mode 100644 llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h diff --git a/clang/include/clang/Sema/ParseHLSLRootSignature.h b/clang/include/clang/Sema/ParseHLSLRootSignature.h new file mode 100644 index 00000000000000..7d1799e22b515c --- /dev/null +++ b/clang/include/clang/Sema/ParseHLSLRootSignature.h @@ -0,0 +1,63 @@ +//===--- ParseHLSLRootSignature.h -------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the ParseHLSLRootSignature interface. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H +#define LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" + +#include "llvm/Frontend/HLSL/HLSLRootSignature.h" + +namespace llvm { +namespace hlsl { +namespace root_signature { + +class Parser { +public: + Parser(StringRef Signature, SmallVector<RootElement> *Elements) + : Buffer(Signature), Elements(Elements) {} + + // Consumes the internal buffer as a list of root elements and will + // emplace their in-memory representation onto the back of Elements. + // + // It will consume until it successfully reaches the end of the buffer, + // or until the first error is encountered. Returns true if an error was + // encountered and false once the whole buffer has been parsed.
+ bool Parse(); + +private: + bool ReportError(); + + // RootElements parse methods + bool ParseRootElement(); + + bool ParseRootFlags(); + // Enum methods + template <typename EnumType> + bool ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping, + EnumType &Enum); + bool ParseRootFlag(RootFlags &Flag); + + // Internal state used when parsing + StringRef Buffer; + SmallVector<RootElement> *Elements; + + StringRef Token; +}; + +} // namespace root_signature +} // namespace hlsl +} // namespace llvm + +#endif // LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H diff --git a/clang/lib/Sema/CMakeLists.txt b/clang/lib/Sema/CMakeLists.txt index 719c3a9312ec15..7141bb42eb4363 100644 --- a/clang/lib/Sema/CMakeLists.txt +++ b/clang/lib/Sema/CMakeLists.txt @@ -24,6 +24,7 @@ add_clang_library(clangSema JumpDiagnostics.cpp MultiplexExternalSemaSource.cpp ParsedAttr.cpp + ParseHLSLRootSignature.cpp Scope.cpp ScopeInfo.cpp Sema.cpp diff --git a/clang/lib/Sema/ParseHLSLRootSignature.cpp b/clang/lib/Sema/ParseHLSLRootSignature.cpp new file mode 100644 index 00000000000000..e4592ea1937178 --- /dev/null +++ b/clang/lib/Sema/ParseHLSLRootSignature.cpp @@ -0,0 +1,139 @@ +#include "clang/Sema/ParseHLSLRootSignature.h" + +namespace llvm { +namespace hlsl { +namespace root_signature { + +// TODO: Hook up with Sema to properly report semantic/validation errors +bool Parser::ReportError() { return true; } + +bool Parser::ParseRootFlags() { + // Start from RootFlags::None and skip whitespace so that an empty flag list, + // e.g. `RootFlags( )`, is accepted + RootFlags Flags = RootFlags::None; + Buffer = Buffer.drop_while(isspace); + StringLiteral Prefix = ""; + + // Parse '|'-separated flags until the closing ')' is reached + while (!Buffer.starts_with(")")) { + // Every flag after the first must be preceded by '|' + if (!Buffer.consume_front(Prefix)) + return ReportError(); + Prefix = "|"; + + // Remove any whitespace + Buffer = Buffer.drop_while(isspace); + + RootFlags CurFlag; + if (ParseRootFlag(CurFlag)) + return ReportError(); + Flags
|= CurFlag; + + // Remove any whitespace + Buffer = Buffer.drop_while(isspace); + } + + // Create and push the root element on the parsed elements + Elements->push_back(RootElement(Flags)); + return false; +} + +template <typename EnumType> +bool Parser::ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping, + EnumType &Enum) { + // The enum token is the longest prefix of alphanumerics and underscores + Token = Buffer.take_while([](char C) { return isalnum(C) || C == '_'; }); + Buffer = Buffer.drop_front(Token.size()); + + // Match the token case-insensitively against each Mapping entry + auto Switch = llvm::StringSwitch<std::optional<EnumType>>(Token); + for (auto Pair : Mapping) + Switch.CaseLower(Pair.first, Pair.second); + auto MaybeEnum = Switch.Default(std::nullopt); + if (!MaybeEnum) + return true; + Enum = *MaybeEnum; + + return false; +} + +bool Parser::ParseRootFlag(RootFlags &Flag) { + SmallVector<std::pair<StringLiteral, RootFlags>> Mapping = { + {"0", RootFlags::None}, + {"ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT", + RootFlags::AllowInputAssemblerInputLayout}, + {"DENY_VERTEX_SHADER_ROOT_ACCESS", RootFlags::DenyVertexShaderRootAccess}, + {"DENY_HULL_SHADER_ROOT_ACCESS", RootFlags::DenyHullShaderRootAccess}, + {"DENY_DOMAIN_SHADER_ROOT_ACCESS", RootFlags::DenyDomainShaderRootAccess}, + {"DENY_GEOMETRY_SHADER_ROOT_ACCESS", + RootFlags::DenyGeometryShaderRootAccess}, + {"DENY_PIXEL_SHADER_ROOT_ACCESS", RootFlags::DenyPixelShaderRootAccess}, + {"ALLOW_STREAM_OUTPUT", RootFlags::AllowStreamOutput}, + {"LOCAL_ROOT_SIGNATURE", RootFlags::LocalRootSignature}, + {"DENY_AMPLIFICATION_SHADER_ROOT_ACCESS", + RootFlags::DenyAmplificationShaderRootAccess}, + {"DENY_MESH_SHADER_ROOT_ACCESS", RootFlags::DenyMeshShaderRootAccess}, + {"CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED", + RootFlags::CBVSRVUAVHeapDirectlyIndexed}, + {"SAMPLER_HEAP_DIRECTLY_INDEXED", RootFlags::SamplerHeapDirectlyIndexed}, + {"AllowLowTierReservedHwCbLimit", + RootFlags::AllowLowTierReservedHwCbLimit}, + }; + + return ParseEnum<RootFlags>(Mapping, Flag); +} + +bool
Parser::ParseRootElement() { + // Define different ParserMethods to use StringSwitch for dispatch + enum class ParserMethod { + ReportError, + ParseRootFlags, + }; + + // Retreive which method should be used + auto Method = llvm::StringSwitch<ParserMethod>(Token) + .Case("RootFlags", ParserMethod::ParseRootFlags) + .Default(ParserMethod::ReportError); + + // Dispatch on the correct method + switch (Method) { + case ParserMethod::ReportError: + return ReportError(); + case ParserMethod::ParseRootFlags: + return ParseRootFlags(); + } +} + +// Parser entry point: parses ','-separated root elements until Buffer is empty +bool Parser::Parse() { + StringLiteral Prefix = ""; + while (!Buffer.empty()) { + // Every root element after the first must be preceded by ',' + if (!Buffer.consume_front(Prefix)) + return ReportError(); + Prefix = ","; + + // Remove any whitespace + Buffer = Buffer.drop_while(isspace); + + // The root element identifier is everything before the first '(' + auto Split = Buffer.split('('); + Token = Split.first; + Buffer = Split.second; + + // Dispatch to the applicable root element parser + if (ParseRootElement()) + return true; + + // Element parsers stop before the closing ')', so consume it here + if (!Buffer.consume_front(")")) + return ReportError(); + } + + // All input has been correctly parsed + return false; +} + +} // namespace root_signature +} // namespace hlsl +} // namespace llvm diff --git a/clang/unittests/Sema/CMakeLists.txt b/clang/unittests/Sema/CMakeLists.txt index 7ded562e8edfa5..f382f8b1235306 100644 --- a/clang/unittests/Sema/CMakeLists.txt +++ b/clang/unittests/Sema/CMakeLists.txt @@ -7,6 +7,7 @@ add_clang_unittest(SemaTests ExternalSemaSourceTest.cpp CodeCompleteTest.cpp GslOwnerPointerInference.cpp + ParseHLSLRootSignatureTest.cpp SemaLookupTest.cpp SemaNoloadLookupTest.cpp ) diff --git a/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp new file mode 100644 index 00000000000000..0e7feb50871669 --- /dev/null +++ b/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp @@ -0,0
+1,58 @@ +//=== ParseHLSLRootSignatureTest.cpp - Parse Root Signature tests ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/Sema/ParseHLSLRootSignature.h" +#include "gtest/gtest.h" + +using namespace llvm::hlsl::root_signature; + +namespace { + +TEST(ParseHLSLRootSignature, EmptyRootFlags) { + llvm::StringRef RootFlagString = " RootFlags()"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(RootFlagString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + ASSERT_EQ(RootFlags::None, RootElements[0].Flags); +} + +TEST(ParseHLSLRootSignature, RootFlagsNone) { + llvm::StringRef RootFlagString = " RootFlags(0)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(RootFlagString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + ASSERT_EQ(RootFlags::None, RootElements[0].Flags); +} + +TEST(ParseHLSLRootSignature, ValidRootFlags) { + // Test that the flags are all captured and that they are case insensitive + llvm::StringRef RootFlagString = " RootFlags( " + " ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT" + "| deny_vertex_shader_root_access" + "| DENY_HULL_SHADER_ROOT_ACCESS" + "| deny_domain_shader_root_access" + "| DENY_GEOMETRY_SHADER_ROOT_ACCESS" + "| deny_pixel_shader_root_access" + "| ALLOW_STREAM_OUTPUT" + "| LOCAL_ROOT_SIGNATURE" + "| deny_amplification_shader_root_access" + "| DENY_MESH_SHADER_ROOT_ACCESS" + "| cbv_srv_uav_heap_directly_indexed" + "| SAMPLER_HEAP_DIRECTLY_INDEXED" + "| AllowLowTierReservedHwCbLimit )"; + + llvm::SmallVector<RootElement> RootElements; + Parser Parser(RootFlagString, &RootElements); + ASSERT_FALSE(Parser.Parse()); +
ASSERT_EQ(RootElements.size(), (unsigned long)1); + ASSERT_EQ(RootFlags::ValidFlags, RootElements[0].Flags); +} + +} // anonymous namespace diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h new file mode 100644 index 00000000000000..a17ebffc7a6bf2 --- /dev/null +++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h @@ -0,0 +1,88 @@ +//===- HLSLRootSignature.h - HLSL Root Signature helper objects -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file This file contains helper objects for working with HLSL Root +/// Signatures. +/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H +#define LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H + +#include <stdint.h> + +#include "llvm/Support/Endian.h" + +namespace llvm { +namespace hlsl { +namespace root_signature { + +// Bitwise operators for flag enums; copied from DebugInfo/CodeView/CodeView.h +#define RS_DEFINE_ENUM_CLASS_FLAGS_OPERATORS(Class) \ + inline Class operator|(Class a, Class b) { \ + return static_cast<Class>(llvm::to_underlying(a) | \ + llvm::to_underlying(b)); \ + } \ + inline Class operator&(Class a, Class b) { \ + return static_cast<Class>(llvm::to_underlying(a) & \ + llvm::to_underlying(b)); \ + } \ + inline Class operator~(Class a) { \ + return static_cast<Class>(~llvm::to_underlying(a)); \ + } \ + inline Class &operator|=(Class &a, Class b) { \ + a = a | b; \ + return a; \ + } \ + inline Class &operator&=(Class &a, Class b) { \ + a = a & b; \ + return a; \ + } + +// Various enumerations and flags + +enum class RootFlags : uint32_t { + None = 0, + AllowInputAssemblerInputLayout = 0x1, + DenyVertexShaderRootAccess =
0x2, + DenyHullShaderRootAccess = 0x4, + DenyDomainShaderRootAccess = 0x8, + DenyGeometryShaderRootAccess = 0x10, + DenyPixelShaderRootAccess = 0x20, + AllowStreamOutput = 0x40, + LocalRootSignature = 0x80, + DenyAmplificationShaderRootAccess = 0x100, + DenyMeshShaderRootAccess = 0x200, + CBVSRVUAVHeapDirectlyIndexed = 0x400, + SamplerHeapDirectlyIndexed = 0x800, + AllowLowTierReservedHwCbLimit = 0x80000000, + ValidFlags = 0x80000fff +}; +RS_DEFINE_ENUM_CLASS_FLAGS_OPERATORS(RootFlags) + +// Define the in-memory layout structures + +struct RootElement { + enum class ElementType { + RootFlags, + }; + + ElementType Tag; + union { + RootFlags Flags; + }; + + // Constructors + RootElement(RootFlags Flags) : Tag(ElementType::RootFlags), Flags(Flags) {} +}; + +} // namespace root_signature +} // namespace hlsl +} // namespace llvm + +#endif // LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H >From 425e8dcc0237f80627c55517cdc1d5f7277a62d2 Mon Sep 17 00:00:00 2001 From: Finn Plummer <canadienf...@gmail.com> Date: Mon, 6 Jan 2025 17:02:07 +0000 Subject: [PATCH 2/2] [HLSL][RootSignature] Implement parsing of `RootParameter`s - Implement the ParseRootParameter methods in ParseHLSLRootSignature - Define the in-memory representation of the RootParameter and add it to the RootElement structure - Add testing of valid inputs to ParseHLSLRootSignatureTest.cpp --- .../clang/Sema/ParseHLSLRootSignature.h | 12 ++ clang/lib/Sema/ParseHLSLRootSignature.cpp | 176 ++++++++++++++++++ .../Sema/ParseHLSLRootSignatureTest.cpp | 98 ++++++++++ .../llvm/Frontend/HLSL/HLSLRootSignature.h | 45 +++++ 4 files changed, 331 insertions(+) diff --git a/clang/include/clang/Sema/ParseHLSLRootSignature.h b/clang/include/clang/Sema/ParseHLSLRootSignature.h index 7d1799e22b515c..f06e02800e5f5f 100644 --- a/clang/include/clang/Sema/ParseHLSLRootSignature.h +++ b/clang/include/clang/Sema/ParseHLSLRootSignature.h @@ -13,6 +13,7 @@ #ifndef LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H #define LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H
+#include "llvm/ADT/APInt.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" @@ -43,11 +44,22 @@ class Parser { bool ParseRootElement(); bool ParseRootFlags(); + bool ParseRootParameter(); + + // Helper methods + bool ParseAssign(); + bool ParseComma(); + bool ParseOptComma(); + bool ParseRegister(Register &); + bool ParseUnsignedInt(uint32_t &Number); + // Enum methods template <typename EnumType> bool ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping, EnumType &Enum); + bool ParseRootDescriptorFlag(RootDescriptorFlags &Flag); bool ParseRootFlag(RootFlags &Flag); + bool ParseVisibility(ShaderVisibility &Visibility); // Internal state used when parsing StringRef Buffer; diff --git a/clang/lib/Sema/ParseHLSLRootSignature.cpp b/clang/lib/Sema/ParseHLSLRootSignature.cpp index e4592ea1937178..4aaf2f34c1eb4c 100644 --- a/clang/lib/Sema/ParseHLSLRootSignature.cpp +++ b/clang/lib/Sema/ParseHLSLRootSignature.cpp @@ -38,6 +38,148 @@ bool Parser::ParseRootFlags() { return false; } +bool Parser::ParseRootParameter() { + RootParameter Parameter; + Parameter.Type = llvm::StringSwitch<RootType>(Token) + .Case("CBV", RootType::CBV) + .Case("SRV", RootType::SRV) + .Case("UAV", RootType::UAV) + .Case("RootConstants", RootType::Constants); + // No Default needed: ParseRootElement only dispatches these four tokens + + // Remove any whitespace + Buffer = Buffer.drop_while(isspace); + + // Retrieve mandatory num32BitConstants arg for RootConstants + if (Parameter.Type == RootType::Constants) { + if (!Buffer.consume_front("num32BitConstants")) + return ReportError(); + + if (ParseAssign()) + return ReportError(); + + if (ParseUnsignedInt(Parameter.Num32BitConstants)) + return ReportError(); + + if (ParseOptComma()) + return ReportError(); + } + + // Retrieve mandatory register + if (ParseRegister(Parameter.Register)) + return true; + + if (ParseOptComma()) + return ReportError(); + + // Parse common optional space arg + if
(Buffer.consume_front("space")) { + if (ParseAssign()) + return ReportError(); + + if (ParseUnsignedInt(Parameter.Space)) + return ReportError(); + + if (ParseOptComma()) + return ReportError(); + } + + // Parse common optional visibility arg + if (Buffer.consume_front("visibility")) { + if (ParseAssign()) + return ReportError(); + + if (ParseVisibility(Parameter.Visibility)) + return ReportError(); + + if (ParseOptComma()) + return ReportError(); + } + + // Retrieve optional flags arg for non-RootConstants + if (Parameter.Type != RootType::Constants && Buffer.consume_front("flags")) { + if (ParseAssign()) + return ReportError(); + + if (ParseRootDescriptorFlag(Parameter.Flags)) + return ReportError(); + + // Remove trailing whitespace + Buffer = Buffer.drop_while(isspace); + } + + // Create and push the root element on the parsed elements + Elements->push_back(RootElement(Parameter)); + return false; +} + +// Helper Parser methods + +// Parses " = " with varying whitespace +bool Parser::ParseAssign() { + Buffer = Buffer.drop_while(isspace); + if (!Buffer.starts_with('=')) + return true; + Buffer = Buffer.drop_front(); + Buffer = Buffer.drop_while(isspace); + return false; +} + +// Parses "," followed by optional whitespace +bool Parser::ParseComma() { + if (!Buffer.starts_with(',')) + return true; + Buffer = Buffer.drop_front(); + Buffer = Buffer.drop_while(isspace); + return false; +} + +// Parses an optional "," with whitespace allowed on both sides.
After a comma +// we expect another parameter and return no error; otherwise Buffer must be +// at the unconsumed ')' ending the root element, or an error is returned +bool Parser::ParseOptComma() { + Buffer = Buffer.drop_while(isspace); + if (!ParseComma()) + return false; + return !Buffer.starts_with(')'); +} + +bool Parser::ParseRegister(Register &Register) { + // Parse expected register type ('b', 't', 'u', 's') + if (Buffer.empty()) + return ReportError(); + + // Get type character + Token = Buffer.take_front(); + Buffer = Buffer.drop_front(); + + auto MaybeType = llvm::StringSwitch<std::optional<RegisterType>>(Token) + .Case("b", RegisterType::BReg) + .Case("t", RegisterType::TReg) + .Case("u", RegisterType::UReg) + .Case("s", RegisterType::SReg) + .Default(std::nullopt); + if (!MaybeType) + return ReportError(); + Register.ViewType = *MaybeType; + + if (ParseUnsignedInt(Register.Number)) + return ReportError(); + + return false; +} + +// Parses "[0-9]+" as a uint32_t; values needing more than 32 bits are rejected +bool Parser::ParseUnsignedInt(uint32_t &Number) { + StringRef NumString = Buffer.take_while(isdigit); + APInt X = APInt(32, 0); + if (NumString.getAsInteger(/*radix=*/10, X) || X.getActiveBits() > 32) + return true; + Number = X.getZExtValue(); + Buffer = Buffer.drop_front(NumString.size()); + return false; +} + template <typename EnumType> bool Parser::ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping, EnumType &Enum) { @@ -57,6 +199,18 @@ bool Parser::ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping, return false; } +bool Parser::ParseRootDescriptorFlag(RootDescriptorFlags &Flag) { + SmallVector<std::pair<StringLiteral, RootDescriptorFlags>> Mapping = { + {"0", RootDescriptorFlags::None}, + {"DATA_VOLATILE", RootDescriptorFlags::DataVolatile}, + {"DATA_STATIC_WHILE_SET_AT_EXECUTE", + RootDescriptorFlags::DataStaticWhileSetAtExecute}, + {"DATA_STATIC", RootDescriptorFlags::DataStatic}, + }; + + return ParseEnum<RootDescriptorFlags>(Mapping, Flag); +} + bool
Parser::ParseRootFlag(RootFlags &Flag) { SmallVector<std::pair<StringLiteral, RootFlags>> Mapping = { {"0", RootFlags::None}, @@ -83,16 +237,36 @@ bool Parser::ParseRootFlag(RootFlags &Flag) { return ParseEnum<RootFlags>(Mapping, Flag); } +bool Parser::ParseVisibility(ShaderVisibility &Visibility) { + SmallVector<std::pair<StringLiteral, ShaderVisibility>> Mapping = { + {"SHADER_VISIBILITY_ALL", ShaderVisibility::All}, + {"SHADER_VISIBILITY_VERTEX", ShaderVisibility::Vertex}, + {"SHADER_VISIBILITY_HULL", ShaderVisibility::Hull}, + {"SHADER_VISIBILITY_DOMAIN", ShaderVisibility::Domain}, + {"SHADER_VISIBILITY_GEOMETRY", ShaderVisibility::Geometry}, + {"SHADER_VISIBILITY_PIXEL", ShaderVisibility::Pixel}, + {"SHADER_VISIBILITY_AMPLIFICATION", ShaderVisibility::Amplification}, + {"SHADER_VISIBILITY_MESH", ShaderVisibility::Mesh}, + }; + + return ParseEnum<ShaderVisibility>(Mapping, Visibility); +} + bool Parser::ParseRootElement() { // Define different ParserMethods to use StringSwitch for dispatch enum class ParserMethod { ReportError, ParseRootFlags, + ParseRootParameter, }; // Retreive which method should be used auto Method = llvm::StringSwitch<ParserMethod>(Token) .Case("RootFlags", ParserMethod::ParseRootFlags) + .Case("RootConstants", ParserMethod::ParseRootParameter) + .Case("CBV", ParserMethod::ParseRootParameter) + .Case("SRV", ParserMethod::ParseRootParameter) + .Case("UAV", ParserMethod::ParseRootParameter) .Default(ParserMethod::ReportError); // Dispatch on the correct method @@ -101,6 +275,8 @@ bool Parser::ParseRootElement() { return ReportError(); case ParserMethod::ParseRootFlags: return ParseRootFlags(); + case ParserMethod::ParseRootParameter: + return ParseRootParameter(); } } diff --git a/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp index 0e7feb50871669..eefe20293732ad 100644 --- a/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp +++ b/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp
@@ -55,4 +55,102 @@ TEST(ParseHLSLRootSignature, ValidRootFlags) { ASSERT_EQ(RootFlags::ValidFlags, RootElements[0].Flags); } +TEST(ParseHLSLRootSignature, MandatoryRootConstant) { + llvm::StringRef RootFlagString = "RootConstants(num32BitConstants = 4, b42)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(RootFlagString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + + RootParameter Parameter = RootElements[0].Parameter; + ASSERT_EQ(RootType::Constants, Parameter.Type); + ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType); + ASSERT_EQ((uint32_t)42, Parameter.Register.Number); + ASSERT_EQ((uint32_t)4, Parameter.Num32BitConstants); + ASSERT_EQ((uint32_t)0, Parameter.Space); + ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility); +} + +TEST(ParseHLSLRootSignature, OptionalRootConstant) { + llvm::StringRef RootFlagString = + "RootConstants(num32BitConstants = 4, b42, space = 4, visibility = " + "SHADER_VISIBILITY_DOMAIN)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(RootFlagString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + + RootParameter Parameter = RootElements[0].Parameter; + ASSERT_EQ(RootType::Constants, Parameter.Type); + ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType); + ASSERT_EQ((uint32_t)42, Parameter.Register.Number); + ASSERT_EQ((uint32_t)4, Parameter.Num32BitConstants); + ASSERT_EQ((uint32_t)4, Parameter.Space); + ASSERT_EQ(ShaderVisibility::Domain, Parameter.Visibility); +} + +TEST(ParseHLSLRootSignature, DefaultRootCBV) { + llvm::StringRef ViewsString = "CBV(b0)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(ViewsString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + + RootParameter Parameter = RootElements[0].Parameter; + ASSERT_EQ(RootType::CBV, Parameter.Type); + ASSERT_EQ(RegisterType::BReg,
Parameter.Register.ViewType); + ASSERT_EQ((uint32_t)0, Parameter.Register.Number); + ASSERT_EQ(RootDescriptorFlags::None, Parameter.Flags); + ASSERT_EQ((uint32_t)0, Parameter.Space); + ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility); +} + +TEST(ParseHLSLRootSignature, SampleRootCBV) { + llvm::StringRef ViewsString = "CBV(b982374, space = 1, flags = DATA_STATIC)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(ViewsString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + + RootParameter Parameter = RootElements[0].Parameter; + ASSERT_EQ(RootType::CBV, Parameter.Type); + ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType); + ASSERT_EQ((uint32_t)982374, Parameter.Register.Number); + ASSERT_EQ(RootDescriptorFlags::DataStatic, Parameter.Flags); + ASSERT_EQ((uint32_t)1, Parameter.Space); + ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility); +} + +TEST(ParseHLSLRootSignature, SampleRootSRV) { + llvm::StringRef ViewsString = "SRV(t3, visibility = SHADER_VISIBILITY_MESH, " + "flags = Data_Static_While_Set_At_Execute)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(ViewsString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + + RootParameter Parameter = RootElements[0].Parameter; + ASSERT_EQ(RootType::SRV, Parameter.Type); + ASSERT_EQ(RegisterType::TReg, Parameter.Register.ViewType); + ASSERT_EQ((uint32_t)3, Parameter.Register.Number); + ASSERT_EQ(RootDescriptorFlags::DataStaticWhileSetAtExecute, Parameter.Flags); + ASSERT_EQ((uint32_t)0, Parameter.Space); + ASSERT_EQ(ShaderVisibility::Mesh, Parameter.Visibility); +} + +TEST(ParseHLSLRootSignature, SampleRootUAV) { + llvm::StringRef ViewsString = "UAV(u0, flags = DATA_VOLATILE)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(ViewsString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + + RootParameter
Parameter = RootElements[0].Parameter; + ASSERT_EQ(RootType::UAV, Parameter.Type); + ASSERT_EQ(RegisterType::UReg, Parameter.Register.ViewType); + ASSERT_EQ((uint32_t)0, Parameter.Register.Number); + ASSERT_EQ(RootDescriptorFlags::DataVolatile, Parameter.Flags); + ASSERT_EQ((uint32_t)0, Parameter.Space); + ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility); +} } // anonymous namespace diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h index a17ebffc7a6bf2..61fd47bbb48ab1 100644 --- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h +++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h @@ -65,20 +65,65 @@ enum class RootFlags : uint32_t { }; RS_DEFINE_ENUM_CLASS_FLAGS_OPERATORS(RootFlags) +enum class RootDescriptorFlags : unsigned { + None = 0, + DataVolatile = 0x2, + DataStaticWhileSetAtExecute = 0x4, + DataStatic = 0x8, + ValidFlags = 0xe +}; +RS_DEFINE_ENUM_CLASS_FLAGS_OPERATORS(RootDescriptorFlags) + +enum class ShaderVisibility { + All = 0, + Vertex = 1, + Hull = 2, + Domain = 3, + Geometry = 4, + Pixel = 5, + Amplification = 6, + Mesh = 7, +}; + // Define the in-memory layout structures +// Models the different registers: bReg | tReg | uReg | sReg +enum class RegisterType { BReg, TReg, UReg, SReg }; +struct Register { + RegisterType ViewType; + uint32_t Number; +}; + +// Models RootConstants | RootCBV | RootSRV | RootUAV. The union holds +// Num32BitConstants for RootConstants and Flags for CBV/SRV/UAV +enum class RootType { CBV, SRV, UAV, Constants }; +struct RootParameter { + RootType Type; + root_signature::Register Register; // Qualified for GCC -Wchanges-meaning + union { + uint32_t Num32BitConstants; + RootDescriptorFlags Flags = RootDescriptorFlags::None; + }; + uint32_t Space = 0; + ShaderVisibility Visibility = ShaderVisibility::All; +}; + struct RootElement { enum class ElementType { RootFlags, + RootParameter, }; ElementType Tag; union { RootFlags Flags; + RootParameter Parameter; }; // Constructors RootElement(RootFlags Flags) : Tag(ElementType::RootFlags), Flags(Flags) {} +
RootElement(RootParameter Parameter) + : Tag(ElementType::RootParameter), Parameter(Parameter) {} }; } // namespace root_signature _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits