llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Finn Plummer (inbelic) <details> <summary>Changes</summary> ``` - Implement the ParseRootParameter methods in ParseHLSLRootSignature - Define the in-memory representation of the various RootParameters and add it to the RootElement structure - Add testing of valid inputs to ParseHLSLRootSignatureTest.cpp ``` Part of the work for #<!-- -->120472 --- Patch is 25.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/121803.diff 6 Files Affected: - (added) clang/include/clang/Sema/ParseHLSLRootSignature.h (+75) - (modified) clang/lib/Sema/CMakeLists.txt (+1) - (added) clang/lib/Sema/ParseHLSLRootSignature.cpp (+315) - (modified) clang/unittests/Sema/CMakeLists.txt (+1) - (added) clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp (+156) - (added) llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h (+133) ``````````diff diff --git a/clang/include/clang/Sema/ParseHLSLRootSignature.h b/clang/include/clang/Sema/ParseHLSLRootSignature.h new file mode 100644 index 00000000000000..f06e02800e5f5f --- /dev/null +++ b/clang/include/clang/Sema/ParseHLSLRootSignature.h @@ -0,0 +1,75 @@ +//===--- ParseHLSLRootSignature.h -------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the ParseHLSLRootSignature interface. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H +#define LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H + +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" + +#include "llvm/Frontend/HLSL/HLSLRootSignature.h" + +namespace llvm { +namespace hlsl { +namespace root_signature { + +class Parser { +public: + Parser(StringRef Signature, SmallVector<RootElement> *Elements) + : Buffer(Signature), Elements(Elements) {} + + // Consumes the internal buffer as a list of root elements and will + // emplace their in-memory representation onto the back of Elements. + // + // It will consume until it successfully reaches the end of the buffer, + // or until the first error is encountered. Returns true if an error was + // encountered and false if the whole buffer was parsed successfully. + bool Parse(); + +private: + bool ReportError(); + + // RootElements parse methods + bool ParseRootElement(); + + bool ParseRootFlags(); + bool ParseRootParameter(); + + // Helper methods (each returns true on failure, false on success) + bool ParseAssign(); + bool ParseComma(); + bool ParseOptComma(); + bool ParseRegister(Register &); + bool ParseUnsignedInt(uint32_t &Number); + + // Enum methods (keywords are matched case-insensitively) + template <typename EnumType> + bool ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping, + EnumType &Enum); + bool ParseRootDescriptorFlag(RootDescriptorFlags &Flag); + bool ParseRootFlag(RootFlags &Flag); + bool ParseVisibility(ShaderVisibility &Visibility); + + // Internal state used when parsing: Buffer is the unconsumed input, Token the most recently extracted token, and Elements the caller-provided output vector + StringRef Buffer; + SmallVector<RootElement> *Elements; + + StringRef Token; +}; + +} // namespace root_signature +} // namespace hlsl +} // namespace llvm + +#endif // LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H diff --git a/clang/lib/Sema/CMakeLists.txt b/clang/lib/Sema/CMakeLists.txt index 719c3a9312ec15..7141bb42eb4363 100644 --- a/clang/lib/Sema/CMakeLists.txt +++ b/clang/lib/Sema/CMakeLists.txt @@ -24,6 +24,7 @@ 
add_clang_library(clangSema JumpDiagnostics.cpp MultiplexExternalSemaSource.cpp ParsedAttr.cpp + ParseHLSLRootSignature.cpp Scope.cpp ScopeInfo.cpp Sema.cpp diff --git a/clang/lib/Sema/ParseHLSLRootSignature.cpp b/clang/lib/Sema/ParseHLSLRootSignature.cpp new file mode 100644 index 00000000000000..4aaf2f34c1eb4c --- /dev/null +++ b/clang/lib/Sema/ParseHLSLRootSignature.cpp @@ -0,0 +1,315 @@ +#include "clang/Sema/ParseHLSLRootSignature.h" + +namespace llvm { +namespace hlsl { +namespace root_signature { + +// TODO: Hook up with Sema to properly report semantic/validation errors +bool Parser::ReportError() { return true; } + +bool Parser::ParseRootFlags() { + // Set to RootFlags::None and skip whitespace to catch when we have RootFlags( + // ). NOTE(review): isspace/isdigit/isalnum are handed plain chars in this file; a negative (non-ASCII) char is UB for <cctype>, so llvm::isSpace/isDigit/isAlnum from StringExtras.h may be safer. + RootFlags Flags = RootFlags::None; + Buffer = Buffer.drop_while(isspace); + StringLiteral Prefix = ""; + + // Loop until we reach the end of the rootflags + while (!Buffer.starts_with(")")) { + // Trim expected | when more than 1 flag + if (!Buffer.consume_front(Prefix)) + return ReportError(); + Prefix = "|"; + + // Remove any whitespace + Buffer = Buffer.drop_while(isspace); + + RootFlags CurFlag; + if (ParseRootFlag(CurFlag)) + return ReportError(); + Flags |= CurFlag; + + // Remove any whitespace + Buffer = Buffer.drop_while(isspace); + } + + // Create and push the root element on the parsed elements + Elements->push_back(RootElement(Flags)); + return false; +} + +bool Parser::ParseRootParameter() { + RootParameter Parameter; + Parameter.Type = llvm::StringSwitch<RootType>(Token) + .Case("CBV", RootType::CBV) + .Case("SRV", RootType::SRV) + .Case("UAV", RootType::UAV) + .Case("RootConstants", RootType::Constants); + // No .Default() is needed: ParseRootElement only dispatches here for these four tokens + + // Remove any whitespace + Buffer = Buffer.drop_while(isspace); + + // Retrieve mandatory num32BitConstants arg for RootConstants + if (Parameter.Type == RootType::Constants) { + if (!Buffer.consume_front("num32BitConstants")) + return 
ReportError(); + + if (ParseAssign()) + return ReportError(); + + if (ParseUnsignedInt(Parameter.Num32BitConstants)) + return ReportError(); + + if (ParseOptComma()) + return ReportError(); + } + + // Retrieve mandatory register + if (ParseRegister(Parameter.Register)) + return true; + + if (ParseOptComma()) + return ReportError(); + + // Parse common optional space arg + if (Buffer.consume_front("space")) { + if (ParseAssign()) + return ReportError(); + + if (ParseUnsignedInt(Parameter.Space)) + return ReportError(); + + if (ParseOptComma()) + return ReportError(); + } + + // Parse common optional visibility arg + if (Buffer.consume_front("visibility")) { + if (ParseAssign()) + return ReportError(); + + if (ParseVisibility(Parameter.Visibility)) + return ReportError(); + + if (ParseOptComma()) + return ReportError(); + } + + // Retrieve optional flags arg for non-RootConstants + if (Parameter.Type != RootType::Constants && Buffer.consume_front("flags")) { + if (ParseAssign()) + return ReportError(); + + if (ParseRootDescriptorFlag(Parameter.Flags)) + return ReportError(); + + // Remove trailing whitespace + Buffer = Buffer.drop_while(isspace); + } + + // Create and push the root element on the parsed elements + Elements->push_back(RootElement(Parameter)); + return false; +} + +// Helper Parser methods + +// Parses " = " with varying whitespace +bool Parser::ParseAssign() { + Buffer = Buffer.drop_while(isspace); + if (!Buffer.starts_with('=')) + return true; + Buffer = Buffer.drop_front(); + Buffer = Buffer.drop_while(isspace); + return false; +} + +// Parses ", " with varying whitespace +bool Parser::ParseComma() { + if (!Buffer.starts_with(',')) + return true; + Buffer = Buffer.drop_front(); + Buffer = Buffer.drop_while(isspace); + return false; +} + +// Parses ", " if possible. 
When successful we expect another parameter, and +// return no error, otherwise we expect that we should be at the end of the +// root element and return an error if this isn't the case +bool Parser::ParseOptComma() { + if (!ParseComma()) + return false; + Buffer = Buffer.drop_while(isspace); + return !Buffer.starts_with(')'); +} + +bool Parser::ParseRegister(Register &Register) { + // Parse expected register type ('b', 't', 'u', 's') + if (Buffer.empty()) + return ReportError(); + + // Get type character + Token = Buffer.take_front(); + Buffer = Buffer.drop_front(); + + auto MaybeType = llvm::StringSwitch<std::optional<RegisterType>>(Token) + .Case("b", RegisterType::BReg) + .Case("t", RegisterType::TReg) + .Case("u", RegisterType::UReg) + .Case("s", RegisterType::SReg) + .Default(std::nullopt); + if (!MaybeType) + return ReportError(); + Register.ViewType = *MaybeType; + + if (ParseUnsignedInt(Register.Number)) + return ReportError(); + + return false; +} + +// Parses "[0-9]+" as a decimal unsigned int. NOTE(review): getAsInteger widens X as needed, so a value above UINT32_MAX is silently truncated by the uint32_t assignment (and one wider than 64 bits would trip getZExtValue's assertion); consider rejecting out-of-range values. +bool Parser::ParseUnsignedInt(uint32_t &Number) { + StringRef NumString = Buffer.take_while(isdigit); + APInt X = APInt(32, 0); + if (NumString.getAsInteger(/*radix=*/10, X)) + return true; + Number = X.getZExtValue(); + Buffer = Buffer.drop_front(NumString.size()); + return false; +} + +template <typename EnumType> +bool Parser::ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping, + EnumType &Enum) { + // Retrieve enum + Token = Buffer.take_while([](char C) { return isalnum(C) || C == '_'; }); + Buffer = Buffer.drop_front(Token.size()); + + // Try to get the case-insensitive enum + auto Switch = llvm::StringSwitch<std::optional<EnumType>>(Token); + for (auto Pair : Mapping) + Switch.CaseLower(Pair.first, Pair.second); + auto MaybeEnum = Switch.Default(std::nullopt); + if (!MaybeEnum) + return true; + Enum = *MaybeEnum; + + return false; +} + +bool Parser::ParseRootDescriptorFlag(RootDescriptorFlags &Flag) { + SmallVector<std::pair<StringLiteral, 
RootDescriptorFlags>> Mapping = { + {"0", RootDescriptorFlags::None}, + {"DATA_VOLATILE", RootDescriptorFlags::DataVolatile}, + {"DATA_STATIC_WHILE_SET_AT_EXECUTE", + RootDescriptorFlags::DataStaticWhileSetAtExecute}, + {"DATA_STATIC", RootDescriptorFlags::DataStatic}, + }; + + return ParseEnum<RootDescriptorFlags>(Mapping, Flag); +} + +bool Parser::ParseRootFlag(RootFlags &Flag) { + SmallVector<std::pair<StringLiteral, RootFlags>> Mapping = { + {"0", RootFlags::None}, + {"ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT", + RootFlags::AllowInputAssemblerInputLayout}, + {"DENY_VERTEX_SHADER_ROOT_ACCESS", RootFlags::DenyVertexShaderRootAccess}, + {"DENY_HULL_SHADER_ROOT_ACCESS", RootFlags::DenyHullShaderRootAccess}, + {"DENY_DOMAIN_SHADER_ROOT_ACCESS", RootFlags::DenyDomainShaderRootAccess}, + {"DENY_GEOMETRY_SHADER_ROOT_ACCESS", + RootFlags::DenyGeometryShaderRootAccess}, + {"DENY_PIXEL_SHADER_ROOT_ACCESS", RootFlags::DenyPixelShaderRootAccess}, + {"ALLOW_STREAM_OUTPUT", RootFlags::AllowStreamOutput}, + {"LOCAL_ROOT_SIGNATURE", RootFlags::LocalRootSignature}, + {"DENY_AMPLIFICATION_SHADER_ROOT_ACCESS", + RootFlags::DenyAmplificationShaderRootAccess}, + {"DENY_MESH_SHADER_ROOT_ACCESS", RootFlags::DenyMeshShaderRootAccess}, + {"CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED", + RootFlags::CBVSRVUAVHeapDirectlyIndexed}, + {"SAMPLER_HEAP_DIRECTLY_INDEXED", RootFlags::SamplerHeapDirectlyIndexed}, + {"AllowLowTierReservedHwCbLimit", + RootFlags::AllowLowTierReservedHwCbLimit}, + }; + + return ParseEnum<RootFlags>(Mapping, Flag); +} + +bool Parser::ParseVisibility(ShaderVisibility &Visibility) { + SmallVector<std::pair<StringLiteral, ShaderVisibility>> Mapping = { + {"SHADER_VISIBILITY_ALL", ShaderVisibility::All}, + {"SHADER_VISIBILITY_VERTEX", ShaderVisibility::Vertex}, + {"SHADER_VISIBILITY_HULL", ShaderVisibility::Hull}, + {"SHADER_VISIBILITY_DOMAIN", ShaderVisibility::Domain}, + {"SHADER_VISIBILITY_GEOMETRY", ShaderVisibility::Geometry}, + {"SHADER_VISIBILITY_PIXEL", ShaderVisibility::Pixel}, 
+ {"SHADER_VISIBILITY_AMPLIFICATION", ShaderVisibility::Amplification}, + {"SHADER_VISIBILITY_MESH", ShaderVisibility::Mesh}, + }; + + return ParseEnum<ShaderVisibility>(Mapping, Visibility); +} + +bool Parser::ParseRootElement() { + // Define different ParserMethods to use StringSwitch for dispatch + enum class ParserMethod { + ReportError, + ParseRootFlags, + ParseRootParameter, + }; + + // Retrieve which method should be used + auto Method = llvm::StringSwitch<ParserMethod>(Token) + .Case("RootFlags", ParserMethod::ParseRootFlags) + .Case("RootConstants", ParserMethod::ParseRootParameter) + .Case("CBV", ParserMethod::ParseRootParameter) + .Case("SRV", ParserMethod::ParseRootParameter) + .Case("UAV", ParserMethod::ParseRootParameter) + .Default(ParserMethod::ReportError); + + // Dispatch on the correct method. NOTE(review): the switch covers every enumerator, but GCC may still warn that control reaches the end of this non-void function; LLVM style follows such a switch with llvm_unreachable. + switch (Method) { + case ParserMethod::ReportError: + return ReportError(); + case ParserMethod::ParseRootFlags: + return ParseRootFlags(); + case ParserMethod::ParseRootParameter: + return ParseRootParameter(); + } +} + +// Parser entry point function +bool Parser::Parse() { + StringLiteral Prefix = ""; + while (!Buffer.empty()) { + // Trim expected comma when more than 1 root element. NOTE(review): no whitespace is skipped after the previous ")", so input such as "RootFlags(0) , CBV(b0)" or a trailing space is rejected here. + if (!Buffer.consume_front(Prefix)) + return ReportError(); + Prefix = ","; + + // Remove any whitespace + Buffer = Buffer.drop_while(isspace); + + // Retrieve the root element identifier (whitespace between it and "(" is not trimmed, so "CBV (b0)" fails dispatch) + auto Split = Buffer.split('('); + Token = Split.first; + Buffer = Split.second; + + // Dispatch to the applicable root element parser + if (ParseRootElement()) + return true; + + // Then we can clean up the remaining ")" + if (!Buffer.consume_front(")")) + return ReportError(); + } + + // All input has been correctly parsed + return false; +} + +} // namespace root_signature +} // namespace hlsl +} // namespace llvm diff --git a/clang/unittests/Sema/CMakeLists.txt b/clang/unittests/Sema/CMakeLists.txt index 7ded562e8edfa5..f382f8b1235306 100644 --- a/clang/unittests/Sema/CMakeLists.txt 
+++ b/clang/unittests/Sema/CMakeLists.txt @@ -7,6 +7,7 @@ add_clang_unittest(SemaTests ExternalSemaSourceTest.cpp CodeCompleteTest.cpp GslOwnerPointerInference.cpp + ParseHLSLRootSignatureTest.cpp SemaLookupTest.cpp SemaNoloadLookupTest.cpp ) diff --git a/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp new file mode 100644 index 00000000000000..eefe20293732ad --- /dev/null +++ b/clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp @@ -0,0 +1,156 @@ +//=== ParseHLSLRootSignatureTest.cpp - Parse Root Signature tests ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "clang/Sema/ParseHLSLRootSignature.h" +#include "gtest/gtest.h" + +using namespace llvm::hlsl::root_signature; + +namespace { + +TEST(ParseHLSLRootSignature, EmptyRootFlags) { + llvm::StringRef RootFlagString = " RootFlags()"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(RootFlagString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + ASSERT_EQ(RootFlags::None, RootElements[0].Flags); +} + +TEST(ParseHLSLRootSignature, RootFlagsNone) { + llvm::StringRef RootFlagString = " RootFlags(0)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(RootFlagString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + ASSERT_EQ(RootFlags::None, RootElements[0].Flags); +} + +TEST(ParseHLSLRootSignature, ValidRootFlags) { + // Test that the flags are all captured and that they are case insensitive + llvm::StringRef RootFlagString = " RootFlags( " + " ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT" + "| deny_vertex_shader_root_access" + "| DENY_HULL_SHADER_ROOT_ACCESS" + "| 
deny_domain_shader_root_access" + "| DENY_GEOMETRY_SHADER_ROOT_ACCESS" + "| deny_pixel_shader_root_access" + "| ALLOW_STREAM_OUTPUT" + "| LOCAL_ROOT_SIGNATURE" + "| deny_amplification_shader_root_access" + "| DENY_MESH_SHADER_ROOT_ACCESS" + "| cbv_srv_uav_heap_directly_indexed" + "| SAMPLER_HEAP_DIRECTLY_INDEXED" + "| AllowLowTierReservedHwCbLimit )"; + + llvm::SmallVector<RootElement> RootElements; + Parser Parser(RootFlagString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + ASSERT_EQ(RootFlags::ValidFlags, RootElements[0].Flags); +} + +TEST(ParseHLSLRootSignature, MandatoryRootConstant) { + llvm::StringRef RootFlagString = "RootConstants(num32BitConstants = 4, b42)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(RootFlagString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + + RootParameter Parameter = RootElements[0].Parameter; + ASSERT_EQ(RootType::Constants, Parameter.Type); + ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType); + ASSERT_EQ((uint32_t)42, Parameter.Register.Number); + ASSERT_EQ((uint32_t)4, Parameter.Num32BitConstants); + ASSERT_EQ((uint32_t)0, Parameter.Space); + ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility); +} + +TEST(ParseHLSLRootSignature, OptionalRootConstant) { + llvm::StringRef RootFlagString = + "RootConstants(num32BitConstants = 4, b42, space = 4, visibility = " + "SHADER_VISIBILITY_DOMAIN)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(RootFlagString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + + RootParameter Parameter = RootElements[0].Parameter; + ASSERT_EQ(RootType::Constants, Parameter.Type); + ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType); + ASSERT_EQ((uint32_t)42, Parameter.Register.Number); + ASSERT_EQ((uint32_t)4, Parameter.Num32BitConstants); + ASSERT_EQ((uint32_t)4, Parameter.Space); + 
ASSERT_EQ(ShaderVisibility::Domain, Parameter.Visibility); +} + +TEST(ParseHLSLRootSignature, DefaultRootCBV) { + llvm::StringRef ViewsString = "CBV(b0)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(ViewsString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + + RootParameter Parameter = RootElements[0].Parameter; + ASSERT_EQ(RootType::CBV, Parameter.Type); + ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType); + ASSERT_EQ((uint32_t)0, Parameter.Register.Number); + ASSERT_EQ(RootDescriptorFlags::None, Parameter.Flags); + ASSERT_EQ((uint32_t)0, Parameter.Space); + ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility); +} + +TEST(ParseHLSLRootSignature, SampleRootCBV) { + llvm::StringRef ViewsString = "CBV(b982374, space = 1, flags = DATA_STATIC)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(ViewsString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + + RootParameter Parameter = RootElements[0].Parameter; + ASSERT_EQ(RootType::CBV, Parameter.Type); + ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType); + ASSERT_EQ((uint32_t)982374, Parameter.Register.Number); + ASSERT_EQ(RootDescriptorFlags::DataStatic, Parameter.Flags); + ASSERT_EQ((uint32_t)1, Parameter.Space); + ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility); +} + +TEST(ParseHLSLRootSignature, SampleRootSRV) { + llvm::StringRef ViewsString = "SRV(t3, visibility = SHADER_VISIBILITY_MESH, " + "flags = Data_Static_While_Set_At_Execute)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(ViewsString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + + RootParameter Parameter = RootElements[0].Parameter; + ASSERT_EQ(RootType::SRV, Parameter.Type); + ASSERT_EQ(RegisterType::TReg, Parameter.Register.ViewType); + ASSERT_EQ((uint32_t)3, Parameter.Register.Number); + 
ASSERT_EQ(RootDescriptorFlags::DataStaticWhileSetAtExecute, Parameter.Flags); + ASSERT_EQ((uint32_t)0, Parameter.Space); + ASSERT_EQ(ShaderVisibility::Mesh, Parameter.Visibility); +} + +TEST(ParseHLSLRootSignature, SampleRootUAV) { + llvm::StringRef ViewsString = "UAV(u0, flags = DATA_VOLATILE)"; + llvm::SmallVector<RootElement> RootElements; + Parser Parser(ViewsString, &RootElements); + ASSERT_FALSE(Parser.Parse()); + ASSERT_EQ(RootElements.size(), (unsigned long)1); + + RootParameter Parameter = RootElements[0].Parameter; + ASSERT_EQ(Roo... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/121803 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits