Skip to content

[HLSL][RootSignature] Add parsing of DescriptorRangeFlags #136775

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

inbelic
Copy link
Contributor

@inbelic inbelic commented Apr 22, 2025

  • Defines parseDescriptorRangeFlags to establish a pattern of how flags will be parsed
  • Add corresponding unit tests

Part four of implementing #126569

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" HLSL HLSL Language Support labels Apr 22, 2025
@llvmbot
Copy link
Member

llvmbot commented Apr 22, 2025

@llvm/pr-subscribers-clang

Author: Finn Plummer (inbelic)

Changes
  • Defines parseDescriptorRangeFlags to establish a pattern of how flags will be parsed
  • Add corresponding unit tests

Part four of implementing #126569


Full diff: https://github.com/llvm/llvm-project/pull/136775.diff

4 Files Affected:

  • (modified) clang/include/clang/Parse/ParseHLSLRootSignature.h (+9)
  • (modified) clang/lib/Parse/ParseHLSLRootSignature.cpp (+76)
  • (modified) clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp (+67-2)
  • (modified) llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h (+29)
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index d639ca91c002f..d2e8f4dbcfc0c 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -81,6 +81,7 @@ class RootSignatureParser {
   struct ParsedClauseParams {
     std::optional<llvm::hlsl::rootsig::Register> Reg;
     std::optional<uint32_t> Space;
+    std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags> Flags;
   };
   std::optional<ParsedClauseParams>
   parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
@@ -91,11 +92,19 @@ class RootSignatureParser {
 
   /// Parsing methods of various enums
   std::optional<llvm::hlsl::rootsig::ShaderVisibility> parseShaderVisibility();
+  std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
+  parseDescriptorRangeFlags();
 
   /// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
   /// 32-bit integer
   std::optional<uint32_t> handleUIntLiteral();
 
+  /// Flags may specify the value of '0' to denote that there should be no
+  /// flags set.
+  ///
+  /// Return true if the current int_literal token is '0', otherwise false
+  bool verifyZeroFlag();
+
   /// Invoke the Lexer to consume a token and update CurToken with the result
   void consumeNextToken() { CurToken = Lexer.consumeToken(); }
 
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 8244e91c8f89a..3b9e96017c88d 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -129,6 +129,7 @@ RootSignatureParser::parseDescriptorTableClause() {
     ExpectedReg = TokenKind::sReg;
     break;
   }
+  Clause.setDefaultFlags();
 
   auto Params = parseDescriptorTableClauseParams(ExpectedReg);
   if (!Params.has_value())
@@ -147,6 +148,9 @@ RootSignatureParser::parseDescriptorTableClause() {
   if (Params->Space.has_value())
     Clause.Space = Params->Space.value();
 
+  if (Params->Flags.has_value())
+    Clause.Flags = Params->Flags.value();
+
   if (consumeExpectedToken(TokenKind::pu_r_paren,
                            diag::err_hlsl_unexpected_end_of_params,
                            /*param of=*/ParamKind))
@@ -194,6 +198,24 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
         return std::nullopt;
       Params.Space = Space;
     }
+
+    // `flags` `=` DESCRIPTOR_RANGE_FLAGS
+    if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
+      if (Params.Flags.has_value()) {
+        getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+            << CurToken.TokKind;
+        return std::nullopt;
+      }
+
+      if (consumeExpectedToken(TokenKind::pu_equal))
+        return std::nullopt;
+
+      auto Flags = parseDescriptorRangeFlags();
+      if (!Flags.has_value())
+        return std::nullopt;
+      Params.Flags = Flags;
+    }
+
   } while (tryConsumeExpectedToken(TokenKind::pu_comma));
 
   return Params;
@@ -268,6 +290,54 @@ RootSignatureParser::parseShaderVisibility() {
   return std::nullopt;
 }
 
+template <typename FlagType>
+static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
+  if (!Flags.has_value())
+    return Flag;
+
+  return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
+                               llvm::to_underlying(Flag));
+}
+
+std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
+RootSignatureParser::parseDescriptorRangeFlags() {
+  assert(CurToken.TokKind == TokenKind::pu_equal &&
+         "Expects to only be invoked starting at given keyword");
+
+  // Handle the edge-case of '0' to specify no flags set
+  if (tryConsumeExpectedToken(TokenKind::int_literal)) {
+    if (!verifyZeroFlag()) {
+      getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'";
+      return std::nullopt;
+    }
+    return DescriptorRangeFlags::None;
+  }
+
+  TokenKind Expected[] = {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) TokenKind::en_##NAME,
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+  };
+
+  std::optional<DescriptorRangeFlags> Flags;
+
+  do {
+    if (tryConsumeExpectedToken(Expected)) {
+      switch (CurToken.TokKind) {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON)                              \
+  case TokenKind::en_##NAME:                                                   \
+    Flags =                                                                    \
+        maybeOrFlag<DescriptorRangeFlags>(Flags, DescriptorRangeFlags::NAME);  \
+    break;
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+      default:
+        llvm_unreachable("Switch for consumed enum token was not provided");
+      }
+    }
+  } while (tryConsumeExpectedToken(TokenKind::pu_or));
+
+  return Flags;
+}
+
 std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
   // Parse the numeric value and do semantic checks on its specification
   clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
@@ -290,6 +360,12 @@ std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
   return Val.getExtValue();
 }
 
+bool RootSignatureParser::verifyZeroFlag() {
+  assert(CurToken.TokKind == TokenKind::int_literal);
+  auto X = handleUIntLiteral();
+  return X.has_value() && X.value() == 0;
+}
+
 bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
   return peekExpectedToken(ArrayRef{Expected});
 }
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 1d89567509e72..f4baf1580de61 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -130,10 +130,14 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   const llvm::StringLiteral Source = R"cc(
     DescriptorTable(
       CBV(b0),
-      SRV(space = 3, t42),
+      SRV(space = 3, t42, flags = 0),
       visibility = SHADER_VISIBILITY_PIXEL,
       Sampler(s987, space = +2),
-      UAV(u4294967294)
+      UAV(u4294967294,
+        flags = Descriptors_Volatile | Data_Volatile
+                      | Data_Static_While_Set_At_Execute | Data_Static
+                      | Descriptors_Static_Keeping_Buffer_Bounds_Checks
+      )
     ),
     DescriptorTable()
   )cc";
@@ -159,6 +163,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::BReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 0u);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::DataStaticWhileSetAtExecute);
 
   Elem = Elements[1];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -167,6 +173,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::TReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 42u);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 3u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::None);
 
   Elem = Elements[2];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -175,6 +183,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::SReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 987u);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 2u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::None);
 
   Elem = Elements[3];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -183,6 +193,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::UReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 4294967294u);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::ValidFlags);
 
   Elem = Elements[4];
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
@@ -199,6 +211,35 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   ASSERT_TRUE(Consumer->isSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) {
+  // This test will checks we can set the valid enum for Sampler descriptor
+  // range flags
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(Sampler(s0, flags = DESCRIPTORS_VOLATILE))
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test no diagnostics produced
+  Consumer->setNoDiag();
+
+  ASSERT_FALSE(Parser.parse());
+
+  RootElement Elem = Elements[0];
+  ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::ValidSamplerFlags);
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
 TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
   // This test will checks we can handling trailing commas ','
   const llvm::StringLiteral Source = R"cc(
@@ -383,4 +424,28 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) {
   ASSERT_TRUE(Consumer->isSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) {
+  // This test will check that parsing fails when a non-zero integer literal
+  // is given to flags
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(b0, flags = 3)
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_expected);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
 } // anonymous namespace
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index d51b853942dd3..b247ab9144280 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -23,6 +23,17 @@ namespace rootsig {
 
 // Definition of the various enumerations and flags
 
+enum class DescriptorRangeFlags : unsigned {
+  None = 0,
+  DescriptorsVolatile = 0x1,
+  DataVolatile = 0x2,
+  DataStaticWhileSetAtExecute = 0x4,
+  DataStatic = 0x8,
+  DescriptorsStaticKeepingBufferBoundsChecks = 0x10000,
+  ValidFlags = 0x1000f,
+  ValidSamplerFlags = DescriptorsVolatile,
+};
+
 enum class ShaderVisibility {
   All = 0,
   Vertex = 1,
@@ -55,6 +66,24 @@ struct DescriptorTableClause {
   ClauseType Type;
   Register Reg;
   uint32_t Space = 0;
+  DescriptorRangeFlags Flags;
+
+  void setDefaultFlags() {
+    switch (Type) {
+    case ClauseType::CBuffer:
+      Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute;
+      break;
+    case ClauseType::SRV:
+      Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute;
+      break;
+    case ClauseType::UAV:
+      Flags = DescriptorRangeFlags::DataVolatile;
+      break;
+    case ClauseType::Sampler:
+      Flags = DescriptorRangeFlags::None;
+      break;
+    }
+  }
 };
 
 // Models RootElement : DescriptorTable | DescriptorTableClause

@llvmbot
Copy link
Member

llvmbot commented Apr 22, 2025

@llvm/pr-subscribers-hlsl

Author: Finn Plummer (inbelic)

Changes
  • Defines parseDescriptorRangeFlags to establish a pattern of how flags will be parsed
  • Add corresponding unit tests

Part four of implementing #126569


Full diff: https://github.com/llvm/llvm-project/pull/136775.diff

4 Files Affected:

  • (modified) clang/include/clang/Parse/ParseHLSLRootSignature.h (+9)
  • (modified) clang/lib/Parse/ParseHLSLRootSignature.cpp (+76)
  • (modified) clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp (+67-2)
  • (modified) llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h (+29)
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index d639ca91c002f..d2e8f4dbcfc0c 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -81,6 +81,7 @@ class RootSignatureParser {
   struct ParsedClauseParams {
     std::optional<llvm::hlsl::rootsig::Register> Reg;
     std::optional<uint32_t> Space;
+    std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags> Flags;
   };
   std::optional<ParsedClauseParams>
   parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
@@ -91,11 +92,19 @@ class RootSignatureParser {
 
   /// Parsing methods of various enums
   std::optional<llvm::hlsl::rootsig::ShaderVisibility> parseShaderVisibility();
+  std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
+  parseDescriptorRangeFlags();
 
   /// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
   /// 32-bit integer
   std::optional<uint32_t> handleUIntLiteral();
 
+  /// Flags may specify the value of '0' to denote that there should be no
+  /// flags set.
+  ///
+  /// Return true if the current int_literal token is '0', otherwise false
+  bool verifyZeroFlag();
+
   /// Invoke the Lexer to consume a token and update CurToken with the result
   void consumeNextToken() { CurToken = Lexer.consumeToken(); }
 
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 8244e91c8f89a..3b9e96017c88d 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -129,6 +129,7 @@ RootSignatureParser::parseDescriptorTableClause() {
     ExpectedReg = TokenKind::sReg;
     break;
   }
+  Clause.setDefaultFlags();
 
   auto Params = parseDescriptorTableClauseParams(ExpectedReg);
   if (!Params.has_value())
@@ -147,6 +148,9 @@ RootSignatureParser::parseDescriptorTableClause() {
   if (Params->Space.has_value())
     Clause.Space = Params->Space.value();
 
+  if (Params->Flags.has_value())
+    Clause.Flags = Params->Flags.value();
+
   if (consumeExpectedToken(TokenKind::pu_r_paren,
                            diag::err_hlsl_unexpected_end_of_params,
                            /*param of=*/ParamKind))
@@ -194,6 +198,24 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
         return std::nullopt;
       Params.Space = Space;
     }
+
+    // `flags` `=` DESCRIPTOR_RANGE_FLAGS
+    if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
+      if (Params.Flags.has_value()) {
+        getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+            << CurToken.TokKind;
+        return std::nullopt;
+      }
+
+      if (consumeExpectedToken(TokenKind::pu_equal))
+        return std::nullopt;
+
+      auto Flags = parseDescriptorRangeFlags();
+      if (!Flags.has_value())
+        return std::nullopt;
+      Params.Flags = Flags;
+    }
+
   } while (tryConsumeExpectedToken(TokenKind::pu_comma));
 
   return Params;
@@ -268,6 +290,54 @@ RootSignatureParser::parseShaderVisibility() {
   return std::nullopt;
 }
 
+template <typename FlagType>
+static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
+  if (!Flags.has_value())
+    return Flag;
+
+  return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
+                               llvm::to_underlying(Flag));
+}
+
+std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
+RootSignatureParser::parseDescriptorRangeFlags() {
+  assert(CurToken.TokKind == TokenKind::pu_equal &&
+         "Expects to only be invoked starting at given keyword");
+
+  // Handle the edge-case of '0' to specify no flags set
+  if (tryConsumeExpectedToken(TokenKind::int_literal)) {
+    if (!verifyZeroFlag()) {
+      getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'";
+      return std::nullopt;
+    }
+    return DescriptorRangeFlags::None;
+  }
+
+  TokenKind Expected[] = {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) TokenKind::en_##NAME,
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+  };
+
+  std::optional<DescriptorRangeFlags> Flags;
+
+  do {
+    if (tryConsumeExpectedToken(Expected)) {
+      switch (CurToken.TokKind) {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON)                              \
+  case TokenKind::en_##NAME:                                                   \
+    Flags =                                                                    \
+        maybeOrFlag<DescriptorRangeFlags>(Flags, DescriptorRangeFlags::NAME);  \
+    break;
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+      default:
+        llvm_unreachable("Switch for consumed enum token was not provided");
+      }
+    }
+  } while (tryConsumeExpectedToken(TokenKind::pu_or));
+
+  return Flags;
+}
+
 std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
   // Parse the numeric value and do semantic checks on its specification
   clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
@@ -290,6 +360,12 @@ std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
   return Val.getExtValue();
 }
 
+bool RootSignatureParser::verifyZeroFlag() {
+  assert(CurToken.TokKind == TokenKind::int_literal);
+  auto X = handleUIntLiteral();
+  return X.has_value() && X.value() == 0;
+}
+
 bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
   return peekExpectedToken(ArrayRef{Expected});
 }
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 1d89567509e72..f4baf1580de61 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -130,10 +130,14 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   const llvm::StringLiteral Source = R"cc(
     DescriptorTable(
       CBV(b0),
-      SRV(space = 3, t42),
+      SRV(space = 3, t42, flags = 0),
       visibility = SHADER_VISIBILITY_PIXEL,
       Sampler(s987, space = +2),
-      UAV(u4294967294)
+      UAV(u4294967294,
+        flags = Descriptors_Volatile | Data_Volatile
+                      | Data_Static_While_Set_At_Execute | Data_Static
+                      | Descriptors_Static_Keeping_Buffer_Bounds_Checks
+      )
     ),
     DescriptorTable()
   )cc";
@@ -159,6 +163,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::BReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 0u);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::DataStaticWhileSetAtExecute);
 
   Elem = Elements[1];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -167,6 +173,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::TReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 42u);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 3u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::None);
 
   Elem = Elements[2];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -175,6 +183,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::SReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 987u);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 2u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::None);
 
   Elem = Elements[3];
   ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -183,6 +193,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
             RegisterType::UReg);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 4294967294u);
   ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::ValidFlags);
 
   Elem = Elements[4];
   ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
@@ -199,6 +211,35 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
   ASSERT_TRUE(Consumer->isSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) {
+  // This test will checks we can set the valid enum for Sampler descriptor
+  // range flags
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(Sampler(s0, flags = DESCRIPTORS_VOLATILE))
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test no diagnostics produced
+  Consumer->setNoDiag();
+
+  ASSERT_FALSE(Parser.parse());
+
+  RootElement Elem = Elements[0];
+  ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
+  ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+            DescriptorRangeFlags::ValidSamplerFlags);
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
 TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
   // This test will checks we can handling trailing commas ','
   const llvm::StringLiteral Source = R"cc(
@@ -383,4 +424,28 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) {
   ASSERT_TRUE(Consumer->isSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) {
+  // This test will check that parsing fails when a non-zero integer literal
+  // is given to flags
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(b0, flags = 3)
+    )
+  )cc";
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+  auto TokLoc = SourceLocation();
+
+  hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_expected);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
 } // anonymous namespace
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index d51b853942dd3..b247ab9144280 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -23,6 +23,17 @@ namespace rootsig {
 
 // Definition of the various enumerations and flags
 
+enum class DescriptorRangeFlags : unsigned {
+  None = 0,
+  DescriptorsVolatile = 0x1,
+  DataVolatile = 0x2,
+  DataStaticWhileSetAtExecute = 0x4,
+  DataStatic = 0x8,
+  DescriptorsStaticKeepingBufferBoundsChecks = 0x10000,
+  ValidFlags = 0x1000f,
+  ValidSamplerFlags = DescriptorsVolatile,
+};
+
 enum class ShaderVisibility {
   All = 0,
   Vertex = 1,
@@ -55,6 +66,24 @@ struct DescriptorTableClause {
   ClauseType Type;
   Register Reg;
   uint32_t Space = 0;
+  DescriptorRangeFlags Flags;
+
+  void setDefaultFlags() {
+    switch (Type) {
+    case ClauseType::CBuffer:
+      Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute;
+      break;
+    case ClauseType::SRV:
+      Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute;
+      break;
+    case ClauseType::UAV:
+      Flags = DescriptorRangeFlags::DataVolatile;
+      break;
+    case ClauseType::Sampler:
+      Flags = DescriptorRangeFlags::None;
+      break;
+    }
+  }
 };
 
 // Models RootElement : DescriptorTable | DescriptorTableClause

Icohedron
Icohedron previously approved these changes Apr 24, 2025
Copy link
Contributor

@Icohedron Icohedron left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, I approved the wrong PR. I will review this though.

@Icohedron Icohedron dismissed their stale review April 24, 2025 00:17

Tabbed into the wrong PR when making the approval

@inbelic inbelic changed the base branch from users/inbelic/pr-136751 to main April 24, 2025 20:35
Finn Plummer and others added 2 commits April 24, 2025 20:37
- Defines `parseDescriptorRangeFlags` to establish a pattern of how
flags will be parsed
- Add corresponding unit tests

Part four of implementing llvm#126569
@inbelic inbelic force-pushed the inbelic/rs-parse-flags branch from 16973f4 to eeff952 Compare April 24, 2025 20:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

4 participants