Author: Finn Plummer
Date: 2025-04-01T14:58:30-07:00
New Revision: 676755561d5a2f074411ad289fed55c977571a32

URL: 
https://github.com/llvm/llvm-project/commit/676755561d5a2f074411ad289fed55c977571a32
DIFF: 
https://github.com/llvm/llvm-project/commit/676755561d5a2f074411ad289fed55c977571a32.diff

LOG: Reland "[HLSL][RootSignature] Implement parsing of a DescriptorTable with 
empty clauses" (#133958)

This PR relands https://github.com/llvm/llvm-project/pull/133302.

It resolves two issues:
- Linking error during build,
[here](https://github.com/llvm/llvm-project/pull/133302#issuecomment-2767259848).
The `ParseHLSLRootSignatureTest.cpp` unit tests were missing a dependency on
`clangLex`. This library was added to the test dependencies to resolve the
error. It wasn't caught previously because the library was transitively
linked in most build environments.
- Warning of unused declaration,
[here](https://github.com/llvm/llvm-project/pull/133302#issuecomment-2767091368).
`LexHLSLRootSignature.h` contained a convenience alias of the form
`using TokenKind = enum RootSignatureToken::Kind`, which caused this
warning. The alias is removed from the header file and is instead declared
locally in the `.cpp` files that use it.
Notably, the original PR would also have exposed `clang::hlsl::TokenKind`
everywhere the header was included, where it clashed with
`tok::TokenKind`. This is another motivation for the proposed
resolution.

---------

Co-authored-by: Finn Plummer <finnplum...@microsoft.com>

Added: 
    clang/include/clang/Parse/ParseHLSLRootSignature.h
    clang/lib/Parse/ParseHLSLRootSignature.cpp
    clang/unittests/Parse/CMakeLists.txt
    clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
    llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

Modified: 
    clang/include/clang/Basic/DiagnosticParseKinds.td
    clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
    clang/include/clang/Lex/LexHLSLRootSignature.h
    clang/lib/Lex/LexHLSLRootSignature.cpp
    clang/lib/Parse/CMakeLists.txt
    clang/unittests/CMakeLists.txt
    clang/unittests/Lex/LexHLSLRootSignatureTest.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Basic/DiagnosticParseKinds.td 
b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 86c361b4dbcf7..2582e1e5ef0f6 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1830,4 +1830,8 @@ def err_hlsl_virtual_function
 def err_hlsl_virtual_inheritance
     : Error<"virtual inheritance is unsupported in HLSL">;
 
+// HLSL Root Signature diagnostic messages
+// %0: the token kind expected to end the parameter list; %1: the element
+// (e.g. RootSignature, DescriptorTable) whose parameters are being parsed.
+def err_hlsl_unexpected_end_of_params
+    : Error<"expected %0 to denote end of parameters, or, another valid 
parameter of %1">;
+
 } // end of Parser diagnostics

diff  --git a/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def 
b/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
index e6df763920430..c514d3456146a 100644
--- a/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
+++ b/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
@@ -14,16 +14,16 @@
 
//===----------------------------------------------------------------------===//
 
 #ifndef TOK
-#define TOK(X)
+#define TOK(X, SPELLING)
 #endif
 #ifndef PUNCTUATOR
-#define PUNCTUATOR(X,Y) TOK(pu_ ## X)
+#define PUNCTUATOR(X,Y) TOK(pu_ ## X, Y)
 #endif
 #ifndef KEYWORD
-#define KEYWORD(X) TOK(kw_ ## X)
+#define KEYWORD(X) TOK(kw_ ## X, #X)
 #endif
 #ifndef ENUM
-#define ENUM(NAME, LIT) TOK(en_ ## NAME)
+#define ENUM(NAME, LIT) TOK(en_ ## NAME, LIT)
 #endif
 
 // Defines the various types of enum
@@ -49,15 +49,15 @@
 #endif
 
 // General Tokens:
-TOK(invalid)
-TOK(end_of_stream)
-TOK(int_literal)
+TOK(invalid, "invalid identifier")
+TOK(end_of_stream, "end of stream")
+TOK(int_literal, "integer literal")
 
 // Register Tokens:
-TOK(bReg)
-TOK(tReg)
-TOK(uReg)
-TOK(sReg)
+TOK(bReg, "b register")
+TOK(tReg, "t register")
+TOK(uReg, "u register")
+TOK(sReg, "s register")
 
 // Punctuators:
 PUNCTUATOR(l_paren, '(')
@@ -69,6 +69,7 @@ PUNCTUATOR(plus,    '+')
 PUNCTUATOR(minus,   '-')
 
 // RootElement Keywords:
+KEYWORD(RootSignature) // used only for diagnostic messaging
 KEYWORD(DescriptorTable)
 
 // DescriptorTable Keywords:

diff  --git a/clang/include/clang/Lex/LexHLSLRootSignature.h 
b/clang/include/clang/Lex/LexHLSLRootSignature.h
index 21c44e0351d9e..4dc80ff546aa0 100644
--- a/clang/include/clang/Lex/LexHLSLRootSignature.h
+++ b/clang/include/clang/Lex/LexHLSLRootSignature.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_CLANG_LEX_LEXHLSLROOTSIGNATURE_H
 #define LLVM_CLANG_LEX_LEXHLSLROOTSIGNATURE_H
 
+#include "clang/Basic/Diagnostic.h"
 #include "clang/Basic/SourceLocation.h"
 
 #include "llvm/ADT/SmallVector.h"
@@ -24,11 +25,11 @@ namespace hlsl {
 
 struct RootSignatureToken {
   enum Kind {
-#define TOK(X) X,
+#define TOK(X, SPELLING) X,
 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
   };
 
-  Kind Kind = Kind::invalid;
+  Kind TokKind = Kind::invalid;
 
   // Retain the SouceLocation of the token for diagnostics
   clang::SourceLocation TokLoc;
@@ -38,10 +39,21 @@ struct RootSignatureToken {
 
   // Constructors
   RootSignatureToken(clang::SourceLocation TokLoc) : TokLoc(TokLoc) {}
-  RootSignatureToken(enum Kind Kind, clang::SourceLocation TokLoc)
-      : Kind(Kind), TokLoc(TokLoc) {}
+  RootSignatureToken(Kind TokKind, clang::SourceLocation TokLoc)
+      : TokKind(TokKind), TokLoc(TokLoc) {}
 };
-using TokenKind = enum RootSignatureToken::Kind;
+
+inline const DiagnosticBuilder &
+operator<<(const DiagnosticBuilder &DB, const RootSignatureToken::Kind Kind) {
+  switch (Kind) {
+#define TOK(X, SPELLING)                                                       
\
+  case RootSignatureToken::Kind::X:                                            
\
+    DB << SPELLING;                                                            
\
+    break;
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+  }
+  return DB;
+}
 
 class RootSignatureLexer {
 public:

diff  --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h 
b/clang/include/clang/Parse/ParseHLSLRootSignature.h
new file mode 100644
index 0000000000000..18cc2c6692551
--- /dev/null
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -0,0 +1,107 @@
+//===--- ParseHLSLRootSignature.h -------------------------------*- C++ 
-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//  This file defines the RootSignatureParser interface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_PARSE_PARSEHLSLROOTSIGNATURE_H
+#define LLVM_CLANG_PARSE_PARSEHLSLROOTSIGNATURE_H
+
+#include "clang/Basic/DiagnosticParse.h"
+#include "clang/Lex/LexHLSLRootSignature.h"
+#include "clang/Lex/Preprocessor.h"
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+
+#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
+
+namespace clang {
+namespace hlsl {
+
+/// Parses root signature tokens supplied by a RootSignatureLexer and appends
+/// the resulting llvm::hlsl::rootsig::RootElements to a caller-owned vector.
+/// The parser holds references to the vector, the lexer and the preprocessor,
+/// so all three must outlive it.
+class RootSignatureParser {
+public:
+  /// \p Elements receives the parsed elements in order (a DescriptorTable is
+  /// appended after the clauses that belong to it). \p Lexer supplies the
+  /// tokens and \p PP provides the DiagnosticsEngine used to report errors.
+  RootSignatureParser(SmallVector<llvm::hlsl::rootsig::RootElement> &Elements,
+                      RootSignatureLexer &Lexer, clang::Preprocessor &PP);
+
+  /// Consumes tokens from the Lexer and constructs the in-memory
+  /// representations of the RootElements. Tokens are consumed until an
+  /// error is encountered or the end of the buffer.
+  ///
+  /// Returns true if a parsing error is encountered. Elements parsed before
+  /// the error may already have been appended to the output vector.
+  bool parse();
+
+private:
+  /// Diagnostics are reported through the preprocessor's DiagnosticsEngine.
+  DiagnosticsEngine &getDiags() { return PP.getDiagnostics(); }
+
+  // All private parse.* methods follow a similar pattern:
+  //   - Each method will start with an assert to denote what the CurToken is
+  // expected to be and will parse from that token forward
+  //
+  //   - Therefore, it is the caller's responsibility to ensure that the parser
+  // is at the correct CurToken. This should be done with the pattern of:
+  //
+  //  if (tryConsumeExpectedToken(RootSignatureToken::Kind))
+  //    if (parse.*())
+  //      return true;
+  //
+  // or,
+  //
+  //  if (consumeExpectedToken(RootSignatureToken::Kind, ...))
+  //    return true;
+  //  if (parse.*())
+  //    return true;
+  //
+  //   - All methods return true if a parsing error is encountered. It is the
+  // caller's responsibility to propagate this error up, or deal with it
+  // otherwise
+  //
+  //   - An error will be raised if the following tokens are not what is
+  // expected, including when the lexer produced an invalid token
+
+  /// Root Element parse methods:
+  bool parseDescriptorTable();
+  bool parseDescriptorTableClause();
+
+  /// Invoke the Lexer to consume a token and update CurToken with the result
+  void consumeNextToken() { CurToken = Lexer.ConsumeToken(); }
+
+  /// Returns true if the next token is one of the expected kinds. The token
+  /// is only peeked at, not consumed.
+  bool peekExpectedToken(RootSignatureToken::Kind Expected);
+  bool peekExpectedToken(ArrayRef<RootSignatureToken::Kind> AnyExpected);
+
+  /// Consumes the next token if it is of the expected kind; otherwise reports
+  /// \p DiagID (with \p Expected, and \p Context where the diagnostic takes
+  /// it) and leaves the token unconsumed.
+  ///
+  /// Returns true if there was an error reported.
+  bool consumeExpectedToken(
+      RootSignatureToken::Kind Expected, unsigned DiagID = diag::err_expected,
+      RootSignatureToken::Kind Context = RootSignatureToken::Kind::invalid);
+
+  /// Peek if the next token is of the expected kind and if it is then consume
+  /// it.
+  ///
+  /// Returns true if it successfully matches the expected kind and the token
+  /// was consumed.
+  bool tryConsumeExpectedToken(RootSignatureToken::Kind Expected);
+  bool tryConsumeExpectedToken(ArrayRef<RootSignatureToken::Kind> Expected);
+
+private:
+  /// Output: parsed elements are appended here.
+  SmallVector<llvm::hlsl::rootsig::RootElement> &Elements;
+  RootSignatureLexer &Lexer;
+
+  clang::Preprocessor &PP;
+
+  /// The most recently consumed token.
+  RootSignatureToken CurToken;
+};
+
+} // namespace hlsl
+} // namespace clang
+
+#endif // LLVM_CLANG_PARSE_PARSEHLSLROOTSIGNATURE_H

diff  --git a/clang/lib/Lex/LexHLSLRootSignature.cpp 
b/clang/lib/Lex/LexHLSLRootSignature.cpp
index fb4aab20c7275..b065d9855ddac 100644
--- a/clang/lib/Lex/LexHLSLRootSignature.cpp
+++ b/clang/lib/Lex/LexHLSLRootSignature.cpp
@@ -11,6 +11,8 @@
 namespace clang {
 namespace hlsl {
 
+using TokenKind = RootSignatureToken::Kind;
+
 // Lexer Definitions
 
 static bool IsNumberChar(char C) {
@@ -34,7 +36,7 @@ RootSignatureToken RootSignatureLexer::LexToken() {
   switch (C) {
 #define PUNCTUATOR(X, Y)                                                       
\
   case Y: {                                                                    
\
-    Result.Kind = TokenKind::pu_##X;                                           
\
+    Result.TokKind = TokenKind::pu_##X;                                        
\
     AdvanceBuffer();                                                           
\
     return Result;                                                             
\
   }
@@ -45,7 +47,7 @@ RootSignatureToken RootSignatureLexer::LexToken() {
 
   // Integer literal
   if (isdigit(C)) {
-    Result.Kind = TokenKind::int_literal;
+    Result.TokKind = TokenKind::int_literal;
     Result.NumSpelling = Buffer.take_while(IsNumberChar);
     AdvanceBuffer(Result.NumSpelling.size());
     return Result;
@@ -65,16 +67,16 @@ RootSignatureToken RootSignatureLexer::LexToken() {
     // Convert character to the register type.
     switch (C) {
     case 'b':
-      Result.Kind = TokenKind::bReg;
+      Result.TokKind = TokenKind::bReg;
       break;
     case 't':
-      Result.Kind = TokenKind::tReg;
+      Result.TokKind = TokenKind::tReg;
       break;
     case 'u':
-      Result.Kind = TokenKind::uReg;
+      Result.TokKind = TokenKind::uReg;
       break;
     case 's':
-      Result.Kind = TokenKind::sReg;
+      Result.TokKind = TokenKind::sReg;
       break;
     default:
       llvm_unreachable("Switch for an expected token was not provided");
@@ -100,14 +102,14 @@ RootSignatureToken RootSignatureLexer::LexToken() {
 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
 
   // Then attempt to retreive a string from it
-  Result.Kind = Switch.Default(TokenKind::invalid);
+  Result.TokKind = Switch.Default(TokenKind::invalid);
   AdvanceBuffer(TokSpelling.size());
   return Result;
 }
 
 RootSignatureToken RootSignatureLexer::ConsumeToken() {
   // If we previously peeked then just return the previous value over
-  if (NextToken && NextToken->Kind != TokenKind::end_of_stream) {
+  if (NextToken && NextToken->TokKind != TokenKind::end_of_stream) {
     RootSignatureToken Result = *NextToken;
     NextToken = std::nullopt;
     return Result;

diff  --git a/clang/lib/Parse/CMakeLists.txt b/clang/lib/Parse/CMakeLists.txt
index 22e902f7e1bc5..00fde537bb9c6 100644
--- a/clang/lib/Parse/CMakeLists.txt
+++ b/clang/lib/Parse/CMakeLists.txt
@@ -14,6 +14,7 @@ add_clang_library(clangParse
   ParseExpr.cpp
   ParseExprCXX.cpp
   ParseHLSL.cpp
+  ParseHLSLRootSignature.cpp
   ParseInit.cpp
   ParseObjc.cpp
   ParseOpenMP.cpp

diff  --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp 
b/clang/lib/Parse/ParseHLSLRootSignature.cpp
new file mode 100644
index 0000000000000..93a9689ebdf72
--- /dev/null
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -0,0 +1,168 @@
+//=== ParseHLSLRootSignature.cpp - Parse Root Signature 
-------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Parse/ParseHLSLRootSignature.h"
+
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm::hlsl::rootsig;
+
+namespace clang {
+namespace hlsl {
+
+// Local shorthand; deliberately not declared in the header, where it would
+// clash with tok::TokenKind for every includer.
+using TokenKind = RootSignatureToken::Kind;
+
+// CurToken starts as a default (invalid) token with an empty location until
+// the first token is consumed.
+RootSignatureParser::RootSignatureParser(SmallVector<RootElement> &Elements,
+                                         RootSignatureLexer &Lexer,
+                                         Preprocessor &PP)
+    : Elements(Elements), Lexer(Lexer), PP(PP), CurToken(SourceLocation()) {}
+
+// Grammar accepted so far:
+//   RootSignature : [RootElement {',' RootElement} [',']] end_of_stream
+//   RootElement   : DescriptorTable
+// NOTE(review): a trailing ',' directly before the end of the stream is
+// currently accepted; confirm this is intended.
+bool RootSignatureParser::parse() {
+  // Iterate as many RootElements as possible
+  while (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
+    // Dispatch onto parser method.
+    // We guard against the unreachable here as we just ensured that CurToken
+    // will be one of the kinds in the while condition
+    switch (CurToken.TokKind) {
+    case TokenKind::kw_DescriptorTable:
+      if (parseDescriptorTable())
+        return true;
+      break;
+    default:
+      llvm_unreachable("Switch for consumed token was not provided");
+    }
+
+    // Elements are comma-separated; no comma means the list has ended.
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
+
+  // Anything other than the end of the stream is unexpected here. The
+  // diagnostic is reported at CurToken's location, i.e. the last consumed
+  // token (or an empty location if nothing was consumed).
+  if (!tryConsumeExpectedToken(TokenKind::end_of_stream)) {
+    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_unexpected_end_of_params)
+        << /*expected=*/TokenKind::end_of_stream
+        << /*param of=*/TokenKind::kw_RootSignature;
+    return true;
+  }
+  return false;
+}
+
+// Parses `DescriptorTable '(' [Clause {',' Clause} [',']] ')'`, starting with
+// CurToken on the DescriptorTable keyword. Each clause is appended to Elements
+// as it is parsed; the DescriptorTable itself is appended afterwards, with
+// NumClauses recording how many clauses precede it.
+bool RootSignatureParser::parseDescriptorTable() {
+  assert(CurToken.TokKind == TokenKind::kw_DescriptorTable &&
+         "Expects to only be invoked starting at given keyword");
+
+  DescriptorTable Table;
+
+  // The keyword itself is the context of the "expected ... after" diagnostic.
+  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
+                           CurToken.TokKind))
+    return true;
+
+  // Iterate as many Clauses as possible
+  while (tryConsumeExpectedToken({TokenKind::kw_CBV, TokenKind::kw_SRV,
+                                  TokenKind::kw_UAV, TokenKind::kw_Sampler})) {
+    if (parseDescriptorTableClause())
+      return true;
+
+    Table.NumClauses++;
+
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
+
+  if (!tryConsumeExpectedToken(TokenKind::pu_r_paren)) {
+    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_unexpected_end_of_params)
+        << /*expected=*/TokenKind::pu_r_paren
+        << /*param of=*/TokenKind::kw_DescriptorTable;
+    return true;
+  }
+
+  Elements.push_back(Table);
+  return false;
+}
+
+// Parses `('CBV' | 'SRV' | 'UAV' | 'Sampler') '(' ')'`, starting with CurToken
+// on the clause keyword, and appends the clause to Elements. Clause parameters
+// are not parsed yet, so only an empty parameter list is accepted.
+bool RootSignatureParser::parseDescriptorTableClause() {
+  assert((CurToken.TokKind == TokenKind::kw_CBV ||
+          CurToken.TokKind == TokenKind::kw_SRV ||
+          CurToken.TokKind == TokenKind::kw_UAV ||
+          CurToken.TokKind == TokenKind::kw_Sampler) &&
+         "Expects to only be invoked starting at given keyword");
+
+  // Map the clause keyword onto its resource class.
+  DescriptorTableClause Clause;
+  switch (CurToken.TokKind) {
+  default:
+    llvm_unreachable("Switch for consumed token was not provided");
+  case TokenKind::kw_CBV:
+    Clause.Type = ClauseType::CBuffer;
+    break;
+  case TokenKind::kw_SRV:
+    Clause.Type = ClauseType::SRV;
+    break;
+  case TokenKind::kw_UAV:
+    Clause.Type = ClauseType::UAV;
+    break;
+  case TokenKind::kw_Sampler:
+    Clause.Type = ClauseType::Sampler;
+    break;
+  }
+
+  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
+                           CurToken.TokKind))
+    return true;
+
+  // CurToken is now the '(' just consumed, so it is the diagnostic context.
+  if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_expected_after,
+                           CurToken.TokKind))
+    return true;
+
+  Elements.push_back(Clause);
+  return false;
+}
+
+// Single-kind convenience overload.
+bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
+  return peekExpectedToken(ArrayRef{Expected});
+}
+
+// Looks at the next token without consuming it.
+bool RootSignatureParser::peekExpectedToken(ArrayRef<TokenKind> AnyExpected) {
+  RootSignatureToken Result = Lexer.PeekNextToken();
+  return llvm::is_contained(AnyExpected, Result.TokKind);
+}
+
+// On a mismatch the token is left unconsumed and DiagID is reported at
+// CurToken's location. Only diagnostics known to take token arguments receive
+// Expected (and Context); any other DiagID is reported without arguments.
+bool RootSignatureParser::consumeExpectedToken(TokenKind Expected,
+                                               unsigned DiagID,
+                                               TokenKind Context) {
+  if (tryConsumeExpectedToken(Expected))
+    return false;
+
+  // Report unexpected token kind error
+  DiagnosticBuilder DB = getDiags().Report(CurToken.TokLoc, DiagID);
+  switch (DiagID) {
+  case diag::err_expected:
+    DB << Expected;
+    break;
+  case diag::err_expected_either:
+  case diag::err_expected_after:
+    DB << Expected << Context;
+    break;
+  default:
+    break;
+  }
+  return true;
+}
+
+// Single-kind convenience overload.
+bool RootSignatureParser::tryConsumeExpectedToken(TokenKind Expected) {
+  return tryConsumeExpectedToken(ArrayRef{Expected});
+}
+
+// CurToken only advances when the next token matches one of AnyExpected.
+bool RootSignatureParser::tryConsumeExpectedToken(
+    ArrayRef<TokenKind> AnyExpected) {
+  // If not the expected token just return
+  if (!peekExpectedToken(AnyExpected))
+    return false;
+  consumeNextToken();
+  return true;
+}
+
+} // namespace hlsl
+} // namespace clang

diff  --git a/clang/unittests/CMakeLists.txt b/clang/unittests/CMakeLists.txt
index 580533a97d700..f3823ba309420 100644
--- a/clang/unittests/CMakeLists.txt
+++ b/clang/unittests/CMakeLists.txt
@@ -49,6 +49,7 @@ endfunction()
 
 add_subdirectory(Basic)
 add_subdirectory(Lex)
+add_subdirectory(Parse)
 add_subdirectory(Driver)
 if(CLANG_ENABLE_STATIC_ANALYZER)
   add_subdirectory(Analysis)

diff  --git a/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp 
b/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
index d72a842922f98..36bd201df1287 100644
--- a/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
@@ -10,6 +10,7 @@
 #include "gtest/gtest.h"
 
 using namespace clang;
+using TokenKind = hlsl::RootSignatureToken::Kind;
 
 namespace {
 
@@ -20,18 +21,18 @@ class LexHLSLRootSignatureTest : public ::testing::Test {
 
   void CheckTokens(hlsl::RootSignatureLexer &Lexer,
                    SmallVector<hlsl::RootSignatureToken> &Computed,
-                   SmallVector<hlsl::TokenKind> &Expected) {
+                   SmallVector<TokenKind> &Expected) {
     for (unsigned I = 0, E = Expected.size(); I != E; ++I) {
       // Skip these to help with the macro generated test
-      if (Expected[I] == hlsl::TokenKind::invalid ||
-          Expected[I] == hlsl::TokenKind::end_of_stream)
+      if (Expected[I] == TokenKind::invalid ||
+          Expected[I] == TokenKind::end_of_stream)
         continue;
       hlsl::RootSignatureToken Result = Lexer.ConsumeToken();
-      ASSERT_EQ(Result.Kind, Expected[I]);
+      ASSERT_EQ(Result.TokKind, Expected[I]);
       Computed.push_back(Result);
     }
     hlsl::RootSignatureToken EndOfStream = Lexer.ConsumeToken();
-    ASSERT_EQ(EndOfStream.Kind, hlsl::TokenKind::end_of_stream);
+    ASSERT_EQ(EndOfStream.TokKind, TokenKind::end_of_stream);
     ASSERT_TRUE(Lexer.EndOfBuffer());
   }
 };
@@ -49,11 +50,10 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexNumbersTest) {
   hlsl::RootSignatureLexer Lexer(Source, TokLoc);
 
   SmallVector<hlsl::RootSignatureToken> Tokens;
-  SmallVector<hlsl::TokenKind> Expected = {
-      hlsl::TokenKind::pu_minus,    hlsl::TokenKind::int_literal,
-      hlsl::TokenKind::int_literal, hlsl::TokenKind::pu_plus,
-      hlsl::TokenKind::int_literal, hlsl::TokenKind::pu_plus,
-      hlsl::TokenKind::int_literal,
+  SmallVector<TokenKind> Expected = {
+      TokenKind::pu_minus,    TokenKind::int_literal, TokenKind::int_literal,
+      TokenKind::pu_plus,     TokenKind::int_literal, TokenKind::pu_plus,
+      TokenKind::int_literal,
   };
   CheckTokens(Lexer, Tokens, Expected);
 
@@ -85,6 +85,8 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) {
 
     (),|=+-
 
+    RootSignature
+
     DescriptorTable
 
     CBV SRV UAV Sampler
@@ -112,8 +114,8 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) {
   hlsl::RootSignatureLexer Lexer(Source, TokLoc);
 
   SmallVector<hlsl::RootSignatureToken> Tokens;
-  SmallVector<hlsl::TokenKind> Expected = {
-#define TOK(NAME) hlsl::TokenKind::NAME,
+  SmallVector<TokenKind> Expected = {
+#define TOK(NAME, SPELLING) TokenKind::NAME,
 #include "clang/Lex/HLSLRootSignatureTokenKinds.def"
   };
 
@@ -134,17 +136,17 @@ TEST_F(LexHLSLRootSignatureTest, 
ValidCaseInsensitiveKeywordsTest) {
   hlsl::RootSignatureLexer Lexer(Source, TokLoc);
 
   SmallVector<hlsl::RootSignatureToken> Tokens;
-  SmallVector<hlsl::TokenKind> Expected = {
-      hlsl::TokenKind::kw_DescriptorTable,
-      hlsl::TokenKind::kw_CBV,
-      hlsl::TokenKind::kw_SRV,
-      hlsl::TokenKind::kw_UAV,
-      hlsl::TokenKind::kw_Sampler,
-      hlsl::TokenKind::kw_space,
-      hlsl::TokenKind::kw_visibility,
-      hlsl::TokenKind::kw_flags,
-      hlsl::TokenKind::kw_numDescriptors,
-      hlsl::TokenKind::kw_offset,
+  SmallVector<TokenKind> Expected = {
+      TokenKind::kw_DescriptorTable,
+      TokenKind::kw_CBV,
+      TokenKind::kw_SRV,
+      TokenKind::kw_UAV,
+      TokenKind::kw_Sampler,
+      TokenKind::kw_space,
+      TokenKind::kw_visibility,
+      TokenKind::kw_flags,
+      TokenKind::kw_numDescriptors,
+      TokenKind::kw_offset,
   };
 
   CheckTokens(Lexer, Tokens, Expected);
@@ -160,26 +162,26 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexPeekTest) {
 
   // Test basic peek
   hlsl::RootSignatureToken Res = Lexer.PeekNextToken();
-  ASSERT_EQ(Res.Kind, hlsl::TokenKind::pu_r_paren);
+  ASSERT_EQ(Res.TokKind, TokenKind::pu_r_paren);
 
   // Ensure it doesn't peek past one element
   Res = Lexer.PeekNextToken();
-  ASSERT_EQ(Res.Kind, hlsl::TokenKind::pu_r_paren);
+  ASSERT_EQ(Res.TokKind, TokenKind::pu_r_paren);
 
   Res = Lexer.ConsumeToken();
-  ASSERT_EQ(Res.Kind, hlsl::TokenKind::pu_r_paren);
+  ASSERT_EQ(Res.TokKind, TokenKind::pu_r_paren);
 
   // Invoke after reseting the NextToken
   Res = Lexer.PeekNextToken();
-  ASSERT_EQ(Res.Kind, hlsl::TokenKind::int_literal);
+  ASSERT_EQ(Res.TokKind, TokenKind::int_literal);
 
   // Ensure we can still consume the second token
   Res = Lexer.ConsumeToken();
-  ASSERT_EQ(Res.Kind, hlsl::TokenKind::int_literal);
+  ASSERT_EQ(Res.TokKind, TokenKind::int_literal);
 
   // Ensure end of stream token
   Res = Lexer.PeekNextToken();
-  ASSERT_EQ(Res.Kind, hlsl::TokenKind::end_of_stream);
+  ASSERT_EQ(Res.TokKind, TokenKind::end_of_stream);
 }
 
 } // anonymous namespace

diff  --git a/clang/unittests/Parse/CMakeLists.txt 
b/clang/unittests/Parse/CMakeLists.txt
new file mode 100644
index 0000000000000..2a31be625042e
--- /dev/null
+++ b/clang/unittests/Parse/CMakeLists.txt
@@ -0,0 +1,20 @@
+set(LLVM_LINK_COMPONENTS
+  Support
+  )
+# Unit tests for the HLSL root signature parser.
+add_clang_unittest(ParseTests
+  ParseHLSLRootSignatureTest.cpp
+  )
+# clangLex is listed explicitly because the tests use the root signature lexer
+# directly; relying on it being linked transitively caused link failures in
+# some build environments.
+clang_target_link_libraries(ParseTests
+  PRIVATE
+  clangAST
+  clangBasic
+  clangLex
+  clangParse
+  clangSema
+  )
+target_link_libraries(ParseTests
+  PRIVATE
+  LLVMTestingAnnotations
+  LLVMTestingSupport
+  clangTesting
+  )

diff  --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp 
b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
new file mode 100644
index 0000000000000..acdf455a5d6aa
--- /dev/null
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -0,0 +1,245 @@
+//=== ParseHLSLRootSignatureTest.cpp - Parse Root Signature tests 
---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Basic/Diagnostic.h"
+#include "clang/Basic/DiagnosticOptions.h"
+#include "clang/Basic/FileManager.h"
+#include "clang/Basic/LangOptions.h"
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Basic/SourceManager.h"
+#include "clang/Basic/TargetInfo.h"
+#include "clang/Lex/HeaderSearch.h"
+#include "clang/Lex/HeaderSearchOptions.h"
+#include "clang/Lex/Lexer.h"
+#include "clang/Lex/ModuleLoader.h"
+#include "clang/Lex/Preprocessor.h"
+#include "clang/Lex/PreprocessorOptions.h"
+
+#include "clang/Lex/LexHLSLRootSignature.h"
+#include "clang/Parse/ParseHLSLRootSignature.h"
+#include "gtest/gtest.h"
+
+using namespace clang;
+using namespace llvm::hlsl::rootsig;
+
+namespace {
+
+// Diagnostic consumer used to check the diagnostics a parse produces: either
+// none at all (setNoDiag) or exactly one with a given ID (setExpected).
+class ExpectedDiagConsumer : public DiagnosticConsumer {
+  virtual void anchor() {}
+
+  void HandleDiagnostic(DiagnosticsEngine::Level DiagLevel,
+                        const Diagnostic &Info) override {
+    // A diagnostic when none is expected, or any diagnostic after the first,
+    // fails the expectation.
+    if (!FirstDiag || !ExpectedDiagID.has_value()) {
+      Satisfied = false;
+      return;
+    }
+    FirstDiag = false;
+
+    Satisfied = ExpectedDiagID.value() == Info.getID();
+  }
+
+  bool FirstDiag = true;
+  bool Satisfied = false;
+  std::optional<unsigned> ExpectedDiagID;
+
+public:
+  /// Expect no diagnostics from this point on.
+  void setNoDiag() {
+    Satisfied = true;
+    FirstDiag = true;
+    ExpectedDiagID = std::nullopt;
+  }
+
+  /// Expect exactly one diagnostic, with ID \p DiagID, from this point on.
+  void setExpected(unsigned DiagID) {
+    Satisfied = false;
+    FirstDiag = true;
+    ExpectedDiagID = DiagID;
+  }
+
+  bool isSatisfied() { return Satisfied; }
+};
+
+// The test fixture.
+class ParseHLSLRootSignatureTest : public ::testing::Test {
+protected:
+  ParseHLSLRootSignatureTest()
+      : FileMgr(FileMgrOpts), DiagID(new DiagnosticIDs()),
+        Consumer(new ExpectedDiagConsumer()),
+        Diags(DiagID, new DiagnosticOptions, Consumer),
+        SourceMgr(Diags, FileMgr), TargetOpts(new TargetOptions) {
+    // This is an arbitrarily chosen target triple to create the target info.
+    TargetOpts->Triple = "dxil";
+    Target = TargetInfo::CreateTargetInfo(Diags, TargetOpts);
+  }
+
+  // Creates a preprocessor over Source. The HeaderSearch and the options it
+  // is built from are fixture members rather than locals: the Preprocessor
+  // does not own the HeaderSearch (OwnsHeaderSearch = false) but outlives
+  // this function, so a local would leave it with a dangling reference.
+  std::unique_ptr<Preprocessor> createPP(StringRef Source,
+                                         TrivialModuleLoader &ModLoader) {
+    std::unique_ptr<llvm::MemoryBuffer> Buf =
+        llvm::MemoryBuffer::getMemBuffer(Source);
+    SourceMgr.setMainFileID(SourceMgr.createFileID(std::move(Buf)));
+
+    HeaderInfo = std::make_unique<HeaderSearch>(SearchOpts, SourceMgr, Diags,
+                                                LangOpts, Target.get());
+    std::unique_ptr<Preprocessor> PP = std::make_unique<Preprocessor>(
+        std::make_shared<PreprocessorOptions>(), Diags, LangOpts, SourceMgr,
+        *HeaderInfo, ModLoader,
+        /*IILookup =*/nullptr,
+        /*OwnsHeaderSearch =*/false);
+    PP->Initialize(*Target);
+    PP->EnterMainSourceFile();
+    return PP;
+  }
+
+  FileSystemOptions FileMgrOpts;
+  FileManager FileMgr;
+  IntrusiveRefCntPtr<DiagnosticIDs> DiagID;
+  ExpectedDiagConsumer *Consumer;
+  DiagnosticsEngine Diags;
+  SourceManager SourceMgr;
+  LangOptions LangOpts;
+  std::shared_ptr<TargetOptions> TargetOpts;
+  IntrusiveRefCntPtr<TargetInfo> Target;
+  // Declared last so they are destroyed before the objects they refer to.
+  HeaderSearchOptions SearchOpts;
+  std::unique_ptr<HeaderSearch> HeaderInfo;
+};
+
+// Valid Parser Tests
+
+TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyTest) {
+  const llvm::StringLiteral Source = R"cc()cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test no diagnostics produced
+  Consumer->setNoDiag();
+
+  ASSERT_FALSE(Parser.parse());
+  ASSERT_EQ(Elements.size(), 0u);
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(),
+      SRV(),
+      Sampler(),
+      UAV()
+    ),
+    DescriptorTable()
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test no diagnostics produced
+  Consumer->setNoDiag();
+
+  ASSERT_FALSE(Parser.parse());
+
+  // Check the count first so the indexing below cannot go out of bounds.
+  ASSERT_EQ(Elements.size(), 6u);
+
+  // First Descriptor Table: its 4 clauses are emitted before the table itself
+  RootElement Elem = Elements[0];
+  ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::CBuffer);
+
+  Elem = Elements[1];
+  ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::SRV);
+
+  Elem = Elements[2];
+  ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
+
+  Elem = Elements[3];
+  ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::UAV);
+
+  Elem = Elements[4];
+  ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
+  ASSERT_EQ(std::get<DescriptorTable>(Elem).NumClauses, 4u);
+
+  // Empty Descriptor Table
+  Elem = Elements[5];
+  ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
+  ASSERT_EQ(std::get<DescriptorTable>(Elem).NumClauses, 0u);
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+// Invalid Parser Tests
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidParseUnexpectedTokenTest) {
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable()
+    space
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_unexpected_end_of_params);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidParseInvalidTokenTest) {
+  const llvm::StringLiteral Source = R"cc(
+    notAnIdentifier
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test correct diagnostic produced - invalid token
+  Consumer->setExpected(diag::err_hlsl_unexpected_end_of_params);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidParseUnexpectedEndOfStreamTest) {
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test correct diagnostic produced - end of stream
+  Consumer->setExpected(diag::err_expected_after);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+} // anonymous namespace

diff  --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h 
b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
new file mode 100644
index 0000000000000..c1b67844c747f
--- /dev/null
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -0,0 +1,44 @@
+//===- HLSLRootSignature.h - HLSL Root Signature helper objects 
-----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains helper objects for working with HLSL Root
+/// Signatures.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H
+#define LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H
+
+#include "llvm/Support/DXILABI.h"
+#include <variant>
+
+namespace llvm {
+namespace hlsl {
+namespace rootsig {
+
+// Definitions of the in-memory data layout structures
+
+// Models the end of a descriptor table. Currently it only records how many
+// clauses the table has; visibility is not modeled yet. In a flat list of
+// RootElements the table is placed after the clauses that belong to it.
+struct DescriptorTable {
+  uint32_t NumClauses = 0; // The number of clauses in the table
+};
+
+// Models DTClause : CBV | SRV | UAV | Sampler, by collecting like parameters
+using ClauseType = llvm::dxil::ResourceClass;
+struct DescriptorTableClause {
+  ClauseType Type; // Not default-initialized; set by whoever builds the clause
+};
+
+// Models RootElement : DescriptorTable | DescriptorTableClause
+using RootElement = std::variant<DescriptorTable, DescriptorTableClause>;
+
+} // namespace rootsig
+} // namespace hlsl
+} // namespace llvm
+
+#endif // LLVM_FRONTEND_HLSL_HLSLROOTSIGNATURE_H


        
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to