-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[HLSL][RootSignature] Add parsing of flags to RootDescriptor #140152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/inbelic/pr-140151
Are you sure you want to change the base?
[HLSL][RootSignature] Add parsing of flags to RootDescriptor #140152
Conversation
- defines RootDescriptorFlags in-memory representation - defines parseRootDescriptorFlags to be DXC compatible. This is why we support multiple `|` flags even validation will assert that only one flag is set... - add unit tests to demonstrate functionality
@llvm/pr-subscribers-hlsl @llvm/pr-subscribers-clang Author: Finn Plummer (inbelic) Changes
Final part of and resolves #126577 Full diff: https://github.com/llvm/llvm-project/pull/140152.diff 4 Files Affected:
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 436d217cec5b1..7b9168290d62a 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -93,6 +93,7 @@ class RootSignatureParser {
std::optional<llvm::hlsl::rootsig::Register> Reg;
std::optional<uint32_t> Space;
std::optional<llvm::hlsl::rootsig::ShaderVisibility> Visibility;
+ std::optional<llvm::hlsl::rootsig::RootDescriptorFlags> Flags;
};
std::optional<ParsedRootParamParams>
parseRootParamParams(RootSignatureToken::Kind RegType);
@@ -113,6 +114,8 @@ class RootSignatureParser {
/// Parsing methods of various enums
std::optional<llvm::hlsl::rootsig::ShaderVisibility> parseShaderVisibility();
+ std::optional<llvm::hlsl::rootsig::RootDescriptorFlags>
+ parseRootDescriptorFlags();
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
parseDescriptorRangeFlags();
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index edb61f29f10d7..faf261cc9b7fe 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -193,6 +193,7 @@ std::optional<RootParam> RootSignatureParser::parseRootParam() {
ExpectedReg = TokenKind::uReg;
break;
}
+ Param.setDefaultFlags();
auto Params = parseRootParamParams(ExpectedReg);
if (!Params.has_value())
@@ -214,6 +215,9 @@ std::optional<RootParam> RootSignatureParser::parseRootParam() {
if (Params->Visibility.has_value())
Param.Visibility = Params->Visibility.value();
+ if (Params->Flags.has_value())
+ Param.Flags = Params->Flags.value();
+
if (consumeExpectedToken(TokenKind::pu_r_paren,
diag::err_hlsl_unexpected_end_of_params,
/*param of=*/TokenKind::kw_RootConstants))
@@ -475,6 +479,23 @@ RootSignatureParser::parseRootParamParams(TokenKind RegType) {
return std::nullopt;
Params.Visibility = Visibility;
}
+
+ // `flags` `=` ROOT_DESCRIPTOR_FLAGS
+ if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
+ if (Params.Flags.has_value()) {
+ getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+ << CurToken.TokKind;
+ return std::nullopt;
+ }
+
+ if (consumeExpectedToken(TokenKind::pu_equal))
+ return std::nullopt;
+
+ auto Flags = parseRootDescriptorFlags();
+ if (!Flags.has_value())
+ return std::nullopt;
+ Params.Flags = Flags;
+ }
} while (tryConsumeExpectedToken(TokenKind::pu_comma));
return Params;
@@ -654,6 +675,45 @@ RootSignatureParser::parseShaderVisibility() {
return std::nullopt;
}
+std::optional<llvm::hlsl::rootsig::RootDescriptorFlags>
+RootSignatureParser::parseRootDescriptorFlags() {
+ assert(CurToken.TokKind == TokenKind::pu_equal &&
+ "Expects to only be invoked starting at given keyword");
+
+ // Handle the edge-case of '0' to specify no flags set
+ if (tryConsumeExpectedToken(TokenKind::int_literal)) {
+ if (!verifyZeroFlag()) {
+ getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
+ return std::nullopt;
+ }
+ return RootDescriptorFlags::None;
+ }
+
+ TokenKind Expected[] = {
+#define ROOT_DESCRIPTOR_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME,
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+ };
+
+ std::optional<RootDescriptorFlags> Flags;
+
+ do {
+ if (tryConsumeExpectedToken(Expected)) {
+ switch (CurToken.TokKind) {
+#define ROOT_DESCRIPTOR_FLAG_ENUM(NAME, LIT) \
+ case TokenKind::en_##NAME: \
+ Flags = \
+ maybeOrFlag<RootDescriptorFlags>(Flags, RootDescriptorFlags::NAME); \
+ break;
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+ default:
+ llvm_unreachable("Switch for consumed enum token was not provided");
+ }
+ }
+ } while (tryConsumeExpectedToken(TokenKind::pu_or));
+
+ return Flags;
+}
+
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
RootSignatureParser::parseDescriptorRangeFlags() {
assert(CurToken.TokKind == TokenKind::pu_equal &&
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 02bf38dcb110f..7ed286589f8fa 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -347,8 +347,11 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
TEST_F(ParseHLSLRootSignatureTest, ValidParseRootParamsTest) {
const llvm::StringLiteral Source = R"cc(
CBV(b0),
- SRV(space = 4, t42, visibility = SHADER_VISIBILITY_GEOMETRY),
- UAV(visibility = SHADER_VISIBILITY_HULL, u34893247)
+ SRV(space = 4, t42, visibility = SHADER_VISIBILITY_GEOMETRY,
+ flags = DATA_VOLATILE | DATA_STATIC | DATA_STATIC_WHILE_SET_AT_EXECUTE
+ ),
+ UAV(visibility = SHADER_VISIBILITY_HULL, u34893247),
+ CBV(b0, flags = 0),
)cc";
TrivialModuleLoader ModLoader;
@@ -364,7 +367,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootParamsTest) {
ASSERT_FALSE(Parser.parse());
- ASSERT_EQ(Elements.size(), 3u);
+ ASSERT_EQ(Elements.size(), 4u);
RootElement Elem = Elements[0];
ASSERT_TRUE(std::holds_alternative<RootParam>(Elem));
@@ -372,6 +375,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootParamsTest) {
ASSERT_EQ(std::get<RootParam>(Elem).Reg.Number, 0u);
ASSERT_EQ(std::get<RootParam>(Elem).Space, 0u);
ASSERT_EQ(std::get<RootParam>(Elem).Visibility, ShaderVisibility::All);
+ ASSERT_EQ(std::get<RootParam>(Elem).Flags,
+ RootDescriptorFlags::DataStaticWhileSetAtExecute);
Elem = Elements[1];
ASSERT_TRUE(std::holds_alternative<RootParam>(Elem));
@@ -380,6 +385,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootParamsTest) {
ASSERT_EQ(std::get<RootParam>(Elem).Reg.Number, 42u);
ASSERT_EQ(std::get<RootParam>(Elem).Space, 4u);
ASSERT_EQ(std::get<RootParam>(Elem).Visibility, ShaderVisibility::Geometry);
+ ASSERT_EQ(std::get<RootParam>(Elem).Flags, RootDescriptorFlags::ValidFlags);
Elem = Elements[2];
ASSERT_TRUE(std::holds_alternative<RootParam>(Elem));
@@ -388,6 +394,14 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootParamsTest) {
ASSERT_EQ(std::get<RootParam>(Elem).Reg.Number, 34893247u);
ASSERT_EQ(std::get<RootParam>(Elem).Space, 0u);
ASSERT_EQ(std::get<RootParam>(Elem).Visibility, ShaderVisibility::Hull);
+ ASSERT_EQ(std::get<RootParam>(Elem).Flags, RootDescriptorFlags::DataVolatile);
+
+ Elem = Elements[3];
+ ASSERT_EQ(std::get<RootParam>(Elem).Reg.ViewType, RegisterType::BReg);
+ ASSERT_EQ(std::get<RootParam>(Elem).Reg.Number, 0u);
+ ASSERT_EQ(std::get<RootParam>(Elem).Space, 0u);
+ ASSERT_EQ(std::get<RootParam>(Elem).Visibility, ShaderVisibility::All);
+ ASSERT_EQ(std::get<RootParam>(Elem).Flags, RootDescriptorFlags::None);
ASSERT_TRUE(Consumer->isSatisfied());
}
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 7aa55215abae3..98fa5f09429e3 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -46,6 +46,14 @@ enum class RootFlags : uint32_t {
ValidFlags = 0x00000fff
};
+enum class RootDescriptorFlags : unsigned {
+ None = 0,
+ DataVolatile = 0x2,
+ DataStaticWhileSetAtExecute = 0x4,
+ DataStatic = 0x8,
+ ValidFlags = 0xe,
+};
+
enum class DescriptorRangeFlags : unsigned {
None = 0,
DescriptorsVolatile = 0x1,
@@ -91,6 +99,23 @@ struct RootParam {
Register Reg;
uint32_t Space = 0;
ShaderVisibility Visibility = ShaderVisibility::All;
+ RootDescriptorFlags Flags;
+
+ void setDefaultFlags() {
+ assert(Type != ParamType::Sampler &&
+ "Sampler is not a valid type of ParamType");
+ switch (Type) {
+ case ParamType::CBuffer:
+ case ParamType::SRV:
+ Flags = RootDescriptorFlags::DataStaticWhileSetAtExecute;
+ break;
+ case ParamType::UAV:
+ Flags = RootDescriptorFlags::DataVolatile;
+ break;
+ case ParamType::Sampler:
+ break;
+ }
+ }
};
// Models the end of a descriptor table and stores its visibility
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, Just some minor comments and questions for my own clarification
Reviewed but don't have approval permissions yet. LGTM! |
|
flags even validation will assert that only one flag is set...Final part of and resolves #126577