-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[HLSL][RootSignature] Add parsing of DescriptorRangeFlags #136775
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-clang Author: Finn Plummer (inbelic) Changes
Part four of implementing #126569 Full diff: https://github.com/llvm/llvm-project/pull/136775.diff 4 Files Affected:
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index d639ca91c002f..d2e8f4dbcfc0c 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -81,6 +81,7 @@ class RootSignatureParser {
struct ParsedClauseParams {
std::optional<llvm::hlsl::rootsig::Register> Reg;
std::optional<uint32_t> Space;
+ std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags> Flags;
};
std::optional<ParsedClauseParams>
parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
@@ -91,11 +92,19 @@ class RootSignatureParser {
/// Parsing methods of various enums
std::optional<llvm::hlsl::rootsig::ShaderVisibility> parseShaderVisibility();
+ std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
+ parseDescriptorRangeFlags();
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
/// 32-bit integer
std::optional<uint32_t> handleUIntLiteral();
+ /// Flags may specify the value of '0' to denote that there should be no
+ /// flags set.
+ ///
+ /// Return true if the current int_literal token is '0', otherwise false
+ bool verifyZeroFlag();
+
/// Invoke the Lexer to consume a token and update CurToken with the result
void consumeNextToken() { CurToken = Lexer.consumeToken(); }
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 8244e91c8f89a..3b9e96017c88d 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -129,6 +129,7 @@ RootSignatureParser::parseDescriptorTableClause() {
ExpectedReg = TokenKind::sReg;
break;
}
+ Clause.setDefaultFlags();
auto Params = parseDescriptorTableClauseParams(ExpectedReg);
if (!Params.has_value())
@@ -147,6 +148,9 @@ RootSignatureParser::parseDescriptorTableClause() {
if (Params->Space.has_value())
Clause.Space = Params->Space.value();
+ if (Params->Flags.has_value())
+ Clause.Flags = Params->Flags.value();
+
if (consumeExpectedToken(TokenKind::pu_r_paren,
diag::err_hlsl_unexpected_end_of_params,
/*param of=*/ParamKind))
@@ -194,6 +198,24 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
return std::nullopt;
Params.Space = Space;
}
+
+ // `flags` `=` DESCRIPTOR_RANGE_FLAGS
+ if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
+ if (Params.Flags.has_value()) {
+ getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+ << CurToken.TokKind;
+ return std::nullopt;
+ }
+
+ if (consumeExpectedToken(TokenKind::pu_equal))
+ return std::nullopt;
+
+ auto Flags = parseDescriptorRangeFlags();
+ if (!Flags.has_value())
+ return std::nullopt;
+ Params.Flags = Flags;
+ }
+
} while (tryConsumeExpectedToken(TokenKind::pu_comma));
return Params;
@@ -268,6 +290,54 @@ RootSignatureParser::parseShaderVisibility() {
return std::nullopt;
}
+template <typename FlagType>
+static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
+ if (!Flags.has_value())
+ return Flag;
+
+ return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
+ llvm::to_underlying(Flag));
+}
+
+std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
+RootSignatureParser::parseDescriptorRangeFlags() {
+ assert(CurToken.TokKind == TokenKind::pu_equal &&
+ "Expects to only be invoked starting at given keyword");
+
+ // Handle the edge-case of '0' to specify no flags set
+ if (tryConsumeExpectedToken(TokenKind::int_literal)) {
+ if (!verifyZeroFlag()) {
+ getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'";
+ return std::nullopt;
+ }
+ return DescriptorRangeFlags::None;
+ }
+
+ TokenKind Expected[] = {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) TokenKind::en_##NAME,
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+ };
+
+ std::optional<DescriptorRangeFlags> Flags;
+
+ do {
+ if (tryConsumeExpectedToken(Expected)) {
+ switch (CurToken.TokKind) {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) \
+ case TokenKind::en_##NAME: \
+ Flags = \
+ maybeOrFlag<DescriptorRangeFlags>(Flags, DescriptorRangeFlags::NAME); \
+ break;
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+ default:
+ llvm_unreachable("Switch for consumed enum token was not provided");
+ }
+ }
+ } while (tryConsumeExpectedToken(TokenKind::pu_or));
+
+ return Flags;
+}
+
std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
// Parse the numeric value and do semantic checks on its specification
clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
@@ -290,6 +360,12 @@ std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
return Val.getExtValue();
}
+bool RootSignatureParser::verifyZeroFlag() {
+ assert(CurToken.TokKind == TokenKind::int_literal);
+ auto X = handleUIntLiteral();
+ return X.has_value() && X.value() == 0;
+}
+
bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
return peekExpectedToken(ArrayRef{Expected});
}
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 1d89567509e72..f4baf1580de61 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -130,10 +130,14 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
const llvm::StringLiteral Source = R"cc(
DescriptorTable(
CBV(b0),
- SRV(space = 3, t42),
+ SRV(space = 3, t42, flags = 0),
visibility = SHADER_VISIBILITY_PIXEL,
Sampler(s987, space = +2),
- UAV(u4294967294)
+ UAV(u4294967294,
+ flags = Descriptors_Volatile | Data_Volatile
+ | Data_Static_While_Set_At_Execute | Data_Static
+ | Descriptors_Static_Keeping_Buffer_Bounds_Checks
+ )
),
DescriptorTable()
)cc";
@@ -159,6 +163,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RegisterType::BReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 0u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+ DescriptorRangeFlags::DataStaticWhileSetAtExecute);
Elem = Elements[1];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -167,6 +173,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RegisterType::TReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 42u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 3u);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+ DescriptorRangeFlags::None);
Elem = Elements[2];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -175,6 +183,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RegisterType::SReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 987u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 2u);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+ DescriptorRangeFlags::None);
Elem = Elements[3];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -183,6 +193,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RegisterType::UReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 4294967294u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+ DescriptorRangeFlags::ValidFlags);
Elem = Elements[4];
ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
@@ -199,6 +211,35 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}
+TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) {
+ // This test will checks we can set the valid enum for Sampler descriptor
+ // range flags
+ const llvm::StringLiteral Source = R"cc(
+ DescriptorTable(Sampler(s0, flags = DESCRIPTORS_VOLATILE))
+ )cc";
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+ auto TokLoc = SourceLocation();
+
+ hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+ SmallVector<RootElement> Elements;
+ hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+ // Test no diagnostics produced
+ Consumer->setNoDiag();
+
+ ASSERT_FALSE(Parser.parse());
+
+ RootElement Elem = Elements[0];
+ ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+ DescriptorRangeFlags::ValidSamplerFlags);
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
// This test will checks we can handling trailing commas ','
const llvm::StringLiteral Source = R"cc(
@@ -383,4 +424,28 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}
+TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) {
+ // This test will check that parsing fails when a non-zero integer literal
+ // is given to flags
+ const llvm::StringLiteral Source = R"cc(
+ DescriptorTable(
+ CBV(b0, flags = 3)
+ )
+ )cc";
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+ auto TokLoc = SourceLocation();
+
+ hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+ SmallVector<RootElement> Elements;
+ hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+ // Test correct diagnostic produced
+ Consumer->setExpected(diag::err_expected);
+ ASSERT_TRUE(Parser.parse());
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
} // anonymous namespace
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index d51b853942dd3..b247ab9144280 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -23,6 +23,17 @@ namespace rootsig {
// Definition of the various enumerations and flags
+enum class DescriptorRangeFlags : unsigned {
+ None = 0,
+ DescriptorsVolatile = 0x1,
+ DataVolatile = 0x2,
+ DataStaticWhileSetAtExecute = 0x4,
+ DataStatic = 0x8,
+ DescriptorsStaticKeepingBufferBoundsChecks = 0x10000,
+ ValidFlags = 0x1000f,
+ ValidSamplerFlags = DescriptorsVolatile,
+};
+
enum class ShaderVisibility {
All = 0,
Vertex = 1,
@@ -55,6 +66,24 @@ struct DescriptorTableClause {
ClauseType Type;
Register Reg;
uint32_t Space = 0;
+ DescriptorRangeFlags Flags;
+
+ void setDefaultFlags() {
+ switch (Type) {
+ case ClauseType::CBuffer:
+ Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute;
+ break;
+ case ClauseType::SRV:
+ Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute;
+ break;
+ case ClauseType::UAV:
+ Flags = DescriptorRangeFlags::DataVolatile;
+ break;
+ case ClauseType::Sampler:
+ Flags = DescriptorRangeFlags::None;
+ break;
+ }
+ }
};
// Models RootElement : DescriptorTable | DescriptorTableClause
|
@llvm/pr-subscribers-hlsl Author: Finn Plummer (inbelic) Changes
Part four of implementing #126569 Full diff: https://github.com/llvm/llvm-project/pull/136775.diff 4 Files Affected:
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index d639ca91c002f..d2e8f4dbcfc0c 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -81,6 +81,7 @@ class RootSignatureParser {
struct ParsedClauseParams {
std::optional<llvm::hlsl::rootsig::Register> Reg;
std::optional<uint32_t> Space;
+ std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags> Flags;
};
std::optional<ParsedClauseParams>
parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
@@ -91,11 +92,19 @@ class RootSignatureParser {
/// Parsing methods of various enums
std::optional<llvm::hlsl::rootsig::ShaderVisibility> parseShaderVisibility();
+ std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
+ parseDescriptorRangeFlags();
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
/// 32-bit integer
std::optional<uint32_t> handleUIntLiteral();
+ /// Flags may specify the value of '0' to denote that there should be no
+ /// flags set.
+ ///
+ /// Return true if the current int_literal token is '0', otherwise false
+ bool verifyZeroFlag();
+
/// Invoke the Lexer to consume a token and update CurToken with the result
void consumeNextToken() { CurToken = Lexer.consumeToken(); }
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 8244e91c8f89a..3b9e96017c88d 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -129,6 +129,7 @@ RootSignatureParser::parseDescriptorTableClause() {
ExpectedReg = TokenKind::sReg;
break;
}
+ Clause.setDefaultFlags();
auto Params = parseDescriptorTableClauseParams(ExpectedReg);
if (!Params.has_value())
@@ -147,6 +148,9 @@ RootSignatureParser::parseDescriptorTableClause() {
if (Params->Space.has_value())
Clause.Space = Params->Space.value();
+ if (Params->Flags.has_value())
+ Clause.Flags = Params->Flags.value();
+
if (consumeExpectedToken(TokenKind::pu_r_paren,
diag::err_hlsl_unexpected_end_of_params,
/*param of=*/ParamKind))
@@ -194,6 +198,24 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
return std::nullopt;
Params.Space = Space;
}
+
+ // `flags` `=` DESCRIPTOR_RANGE_FLAGS
+ if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
+ if (Params.Flags.has_value()) {
+ getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+ << CurToken.TokKind;
+ return std::nullopt;
+ }
+
+ if (consumeExpectedToken(TokenKind::pu_equal))
+ return std::nullopt;
+
+ auto Flags = parseDescriptorRangeFlags();
+ if (!Flags.has_value())
+ return std::nullopt;
+ Params.Flags = Flags;
+ }
+
} while (tryConsumeExpectedToken(TokenKind::pu_comma));
return Params;
@@ -268,6 +290,54 @@ RootSignatureParser::parseShaderVisibility() {
return std::nullopt;
}
+template <typename FlagType>
+static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
+ if (!Flags.has_value())
+ return Flag;
+
+ return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
+ llvm::to_underlying(Flag));
+}
+
+std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
+RootSignatureParser::parseDescriptorRangeFlags() {
+ assert(CurToken.TokKind == TokenKind::pu_equal &&
+ "Expects to only be invoked starting at given keyword");
+
+ // Handle the edge-case of '0' to specify no flags set
+ if (tryConsumeExpectedToken(TokenKind::int_literal)) {
+ if (!verifyZeroFlag()) {
+ getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'";
+ return std::nullopt;
+ }
+ return DescriptorRangeFlags::None;
+ }
+
+ TokenKind Expected[] = {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) TokenKind::en_##NAME,
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+ };
+
+ std::optional<DescriptorRangeFlags> Flags;
+
+ do {
+ if (tryConsumeExpectedToken(Expected)) {
+ switch (CurToken.TokKind) {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) \
+ case TokenKind::en_##NAME: \
+ Flags = \
+ maybeOrFlag<DescriptorRangeFlags>(Flags, DescriptorRangeFlags::NAME); \
+ break;
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+ default:
+ llvm_unreachable("Switch for consumed enum token was not provided");
+ }
+ }
+ } while (tryConsumeExpectedToken(TokenKind::pu_or));
+
+ return Flags;
+}
+
std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
// Parse the numeric value and do semantic checks on its specification
clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
@@ -290,6 +360,12 @@ std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
return Val.getExtValue();
}
+bool RootSignatureParser::verifyZeroFlag() {
+ assert(CurToken.TokKind == TokenKind::int_literal);
+ auto X = handleUIntLiteral();
+ return X.has_value() && X.value() == 0;
+}
+
bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
return peekExpectedToken(ArrayRef{Expected});
}
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 1d89567509e72..f4baf1580de61 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -130,10 +130,14 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
const llvm::StringLiteral Source = R"cc(
DescriptorTable(
CBV(b0),
- SRV(space = 3, t42),
+ SRV(space = 3, t42, flags = 0),
visibility = SHADER_VISIBILITY_PIXEL,
Sampler(s987, space = +2),
- UAV(u4294967294)
+ UAV(u4294967294,
+ flags = Descriptors_Volatile | Data_Volatile
+ | Data_Static_While_Set_At_Execute | Data_Static
+ | Descriptors_Static_Keeping_Buffer_Bounds_Checks
+ )
),
DescriptorTable()
)cc";
@@ -159,6 +163,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RegisterType::BReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 0u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+ DescriptorRangeFlags::DataStaticWhileSetAtExecute);
Elem = Elements[1];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -167,6 +173,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RegisterType::TReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 42u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 3u);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+ DescriptorRangeFlags::None);
Elem = Elements[2];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -175,6 +183,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RegisterType::SReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 987u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 2u);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+ DescriptorRangeFlags::None);
Elem = Elements[3];
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -183,6 +193,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
RegisterType::UReg);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 4294967294u);
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+ DescriptorRangeFlags::ValidFlags);
Elem = Elements[4];
ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
@@ -199,6 +211,35 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}
+TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) {
+ // This test will checks we can set the valid enum for Sampler descriptor
+ // range flags
+ const llvm::StringLiteral Source = R"cc(
+ DescriptorTable(Sampler(s0, flags = DESCRIPTORS_VOLATILE))
+ )cc";
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+ auto TokLoc = SourceLocation();
+
+ hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+ SmallVector<RootElement> Elements;
+ hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+ // Test no diagnostics produced
+ Consumer->setNoDiag();
+
+ ASSERT_FALSE(Parser.parse());
+
+ RootElement Elem = Elements[0];
+ ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
+ ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
+ DescriptorRangeFlags::ValidSamplerFlags);
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
// This test will checks we can handling trailing commas ','
const llvm::StringLiteral Source = R"cc(
@@ -383,4 +424,28 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}
+TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) {
+ // This test will check that parsing fails when a non-zero integer literal
+ // is given to flags
+ const llvm::StringLiteral Source = R"cc(
+ DescriptorTable(
+ CBV(b0, flags = 3)
+ )
+ )cc";
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+ auto TokLoc = SourceLocation();
+
+ hlsl::RootSignatureLexer Lexer(Source, TokLoc);
+ SmallVector<RootElement> Elements;
+ hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
+
+ // Test correct diagnostic produced
+ Consumer->setExpected(diag::err_expected);
+ ASSERT_TRUE(Parser.parse());
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
} // anonymous namespace
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index d51b853942dd3..b247ab9144280 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -23,6 +23,17 @@ namespace rootsig {
// Definition of the various enumerations and flags
+enum class DescriptorRangeFlags : unsigned {
+ None = 0,
+ DescriptorsVolatile = 0x1,
+ DataVolatile = 0x2,
+ DataStaticWhileSetAtExecute = 0x4,
+ DataStatic = 0x8,
+ DescriptorsStaticKeepingBufferBoundsChecks = 0x10000,
+ ValidFlags = 0x1000f,
+ ValidSamplerFlags = DescriptorsVolatile,
+};
+
enum class ShaderVisibility {
All = 0,
Vertex = 1,
@@ -55,6 +66,24 @@ struct DescriptorTableClause {
ClauseType Type;
Register Reg;
uint32_t Space = 0;
+ DescriptorRangeFlags Flags;
+
+ void setDefaultFlags() {
+ switch (Type) {
+ case ClauseType::CBuffer:
+ Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute;
+ break;
+ case ClauseType::SRV:
+ Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute;
+ break;
+ case ClauseType::UAV:
+ Flags = DescriptorRangeFlags::DataVolatile;
+ break;
+ case ClauseType::Sampler:
+ Flags = DescriptorRangeFlags::None;
+ break;
+ }
+ }
};
// Models RootElement : DescriptorTable | DescriptorTableClause
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oops, I approved the wrong PR. I will review this though.
Tabbed into the wrong PR when making the approval
- Defines `parseDescriptorRangeFlags` to establish a pattern of how flags will be parsed - Add corresponding unit tests Part four of implementing llvm#126569
16973f4
to
eeff952
Compare
parseDescriptorRangeFlags
to establish a pattern of how flags will be parsedPart four of implementing #126569