Skip to content

Commit cf2bf43

Browse files
Finn Plummerinbelic
Finn Plummer
authored andcommitted
[HLSL][RootSignature] Add parsing of DescriptorRangeFlags
- Defines `parseDescriptorRangeFlags` to establish a pattern of how flags will be parsed - Add corresponding unit tests Part four of implementing #126569
1 parent 3c39922 commit cf2bf43

File tree

4 files changed

+181
-2
lines changed

4 files changed

+181
-2
lines changed

clang/include/clang/Parse/ParseHLSLRootSignature.h

+9
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class RootSignatureParser {
8181
struct ParsedClauseParams {
8282
std::optional<llvm::hlsl::rootsig::Register> Reg;
8383
std::optional<uint32_t> Space;
84+
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags> Flags;
8485
};
8586
std::optional<ParsedClauseParams>
8687
parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
@@ -91,11 +92,19 @@ class RootSignatureParser {
9192

9293
/// Parsing methods of various enums
9394
std::optional<llvm::hlsl::rootsig::ShaderVisibility> parseShaderVisibility();
95+
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
96+
parseDescriptorRangeFlags();
9497

9598
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
9699
/// 32-bit integer
97100
std::optional<uint32_t> handleUIntLiteral();
98101

102+
/// Flags may specify the value of '0' to denote that there should be no
103+
/// flags set.
104+
///
105+
/// Return true if the current int_literal token is '0', otherwise false
106+
bool verifyZeroFlag();
107+
99108
/// Invoke the Lexer to consume a token and update CurToken with the result
100109
void consumeNextToken() { CurToken = Lexer.consumeToken(); }
101110

clang/lib/Parse/ParseHLSLRootSignature.cpp

+76
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ RootSignatureParser::parseDescriptorTableClause() {
129129
ExpectedReg = TokenKind::sReg;
130130
break;
131131
}
132+
Clause.setDefaultFlags();
132133

133134
auto Params = parseDescriptorTableClauseParams(ExpectedReg);
134135
if (!Params.has_value())
@@ -147,6 +148,9 @@ RootSignatureParser::parseDescriptorTableClause() {
147148
if (Params->Space.has_value())
148149
Clause.Space = Params->Space.value();
149150

151+
if (Params->Flags.has_value())
152+
Clause.Flags = Params->Flags.value();
153+
150154
if (consumeExpectedToken(TokenKind::pu_r_paren,
151155
diag::err_hlsl_unexpected_end_of_params,
152156
/*param of=*/ParamKind))
@@ -194,6 +198,24 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
194198
return std::nullopt;
195199
Params.Space = Space;
196200
}
201+
202+
// `flags` `=` DESCRIPTOR_RANGE_FLAGS
203+
if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
204+
if (Params.Flags.has_value()) {
205+
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
206+
<< CurToken.TokKind;
207+
return std::nullopt;
208+
}
209+
210+
if (consumeExpectedToken(TokenKind::pu_equal))
211+
return std::nullopt;
212+
213+
auto Flags = parseDescriptorRangeFlags();
214+
if (!Flags.has_value())
215+
return std::nullopt;
216+
Params.Flags = Flags;
217+
}
218+
197219
} while (tryConsumeExpectedToken(TokenKind::pu_comma));
198220

199221
return Params;
@@ -268,6 +290,54 @@ RootSignatureParser::parseShaderVisibility() {
268290
return std::nullopt;
269291
}
270292

293+
template <typename FlagType>
294+
static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
295+
if (!Flags.has_value())
296+
return Flag;
297+
298+
return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
299+
llvm::to_underlying(Flag));
300+
}
301+
302+
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
303+
RootSignatureParser::parseDescriptorRangeFlags() {
304+
assert(CurToken.TokKind == TokenKind::pu_equal &&
305+
"Expects to only be invoked starting at given keyword");
306+
307+
// Handle the edge-case of '0' to specify no flags set
308+
if (tryConsumeExpectedToken(TokenKind::int_literal)) {
309+
if (!verifyZeroFlag()) {
310+
getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'";
311+
return std::nullopt;
312+
}
313+
return DescriptorRangeFlags::None;
314+
}
315+
316+
TokenKind Expected[] = {
317+
#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) TokenKind::en_##NAME,
318+
#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
319+
};
320+
321+
std::optional<DescriptorRangeFlags> Flags;
322+
323+
do {
324+
if (tryConsumeExpectedToken(Expected)) {
325+
switch (CurToken.TokKind) {
326+
#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON) \
327+
case TokenKind::en_##NAME: \
328+
Flags = \
329+
maybeOrFlag<DescriptorRangeFlags>(Flags, DescriptorRangeFlags::NAME); \
330+
break;
331+
#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
332+
default:
333+
llvm_unreachable("Switch for consumed enum token was not provided");
334+
}
335+
}
336+
} while (tryConsumeExpectedToken(TokenKind::pu_or));
337+
338+
return Flags;
339+
}
340+
271341
std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
272342
// Parse the numeric value and do semantic checks on its specification
273343
clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
@@ -290,6 +360,12 @@ std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
290360
return Val.getExtValue();
291361
}
292362

363+
bool RootSignatureParser::verifyZeroFlag() {
364+
assert(CurToken.TokKind == TokenKind::int_literal);
365+
auto X = handleUIntLiteral();
366+
return X.has_value() && X.value() == 0;
367+
}
368+
293369
bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
294370
return peekExpectedToken(ArrayRef{Expected});
295371
}

clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

+67-2
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,14 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
130130
const llvm::StringLiteral Source = R"cc(
131131
DescriptorTable(
132132
CBV(b0),
133-
SRV(space = 3, t42),
133+
SRV(space = 3, t42, flags = 0),
134134
visibility = SHADER_VISIBILITY_PIXEL,
135135
Sampler(s987, space = +2),
136-
UAV(u4294967294)
136+
UAV(u4294967294,
137+
flags = Descriptors_Volatile | Data_Volatile
138+
| Data_Static_While_Set_At_Execute | Data_Static
139+
| Descriptors_Static_Keeping_Buffer_Bounds_Checks
140+
)
137141
),
138142
DescriptorTable()
139143
)cc";
@@ -159,6 +163,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
159163
RegisterType::BReg);
160164
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 0u);
161165
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
166+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
167+
DescriptorRangeFlags::DataStaticWhileSetAtExecute);
162168

163169
Elem = Elements[1];
164170
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -167,6 +173,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
167173
RegisterType::TReg);
168174
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 42u);
169175
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 3u);
176+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
177+
DescriptorRangeFlags::None);
170178

171179
Elem = Elements[2];
172180
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -175,6 +183,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
175183
RegisterType::SReg);
176184
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 987u);
177185
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 2u);
186+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
187+
DescriptorRangeFlags::None);
178188

179189
Elem = Elements[3];
180190
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -183,6 +193,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
183193
RegisterType::UReg);
184194
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 4294967294u);
185195
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
196+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
197+
DescriptorRangeFlags::ValidFlags);
186198

187199
Elem = Elements[4];
188200
ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
@@ -199,6 +211,35 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
199211
ASSERT_TRUE(Consumer->isSatisfied());
200212
}
201213

214+
TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) {
215+
// This test will checks we can set the valid enum for Sampler descriptor
216+
// range flags
217+
const llvm::StringLiteral Source = R"cc(
218+
DescriptorTable(Sampler(s0, flags = DESCRIPTORS_VOLATILE))
219+
)cc";
220+
221+
TrivialModuleLoader ModLoader;
222+
auto PP = createPP(Source, ModLoader);
223+
auto TokLoc = SourceLocation();
224+
225+
hlsl::RootSignatureLexer Lexer(Source, TokLoc);
226+
SmallVector<RootElement> Elements;
227+
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
228+
229+
// Test no diagnostics produced
230+
Consumer->setNoDiag();
231+
232+
ASSERT_FALSE(Parser.parse());
233+
234+
RootElement Elem = Elements[0];
235+
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
236+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
237+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags,
238+
DescriptorRangeFlags::ValidSamplerFlags);
239+
240+
ASSERT_TRUE(Consumer->isSatisfied());
241+
}
242+
202243
TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
203244
// This test will checks we can handling trailing commas ','
204245
const llvm::StringLiteral Source = R"cc(
@@ -383,4 +424,28 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) {
383424
ASSERT_TRUE(Consumer->isSatisfied());
384425
}
385426

427+
TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) {
428+
// This test will check that parsing fails when a non-zero integer literal
429+
// is given to flags
430+
const llvm::StringLiteral Source = R"cc(
431+
DescriptorTable(
432+
CBV(b0, flags = 3)
433+
)
434+
)cc";
435+
436+
TrivialModuleLoader ModLoader;
437+
auto PP = createPP(Source, ModLoader);
438+
auto TokLoc = SourceLocation();
439+
440+
hlsl::RootSignatureLexer Lexer(Source, TokLoc);
441+
SmallVector<RootElement> Elements;
442+
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
443+
444+
// Test correct diagnostic produced
445+
Consumer->setExpected(diag::err_expected);
446+
ASSERT_TRUE(Parser.parse());
447+
448+
ASSERT_TRUE(Consumer->isSatisfied());
449+
}
450+
386451
} // anonymous namespace

llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

+29
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,17 @@ namespace rootsig {
2323

2424
// Definition of the various enumerations and flags
2525

26+
enum class DescriptorRangeFlags : unsigned {
27+
None = 0,
28+
DescriptorsVolatile = 0x1,
29+
DataVolatile = 0x2,
30+
DataStaticWhileSetAtExecute = 0x4,
31+
DataStatic = 0x8,
32+
DescriptorsStaticKeepingBufferBoundsChecks = 0x10000,
33+
ValidFlags = 0x1000f,
34+
ValidSamplerFlags = DescriptorsVolatile,
35+
};
36+
2637
enum class ShaderVisibility {
2738
All = 0,
2839
Vertex = 1,
@@ -55,6 +66,24 @@ struct DescriptorTableClause {
5566
ClauseType Type;
5667
Register Reg;
5768
uint32_t Space = 0;
69+
DescriptorRangeFlags Flags;
70+
71+
void setDefaultFlags() {
72+
switch (Type) {
73+
case ClauseType::CBuffer:
74+
Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute;
75+
break;
76+
case ClauseType::SRV:
77+
Flags = DescriptorRangeFlags::DataStaticWhileSetAtExecute;
78+
break;
79+
case ClauseType::UAV:
80+
Flags = DescriptorRangeFlags::DataVolatile;
81+
break;
82+
case ClauseType::Sampler:
83+
Flags = DescriptorRangeFlags::None;
84+
break;
85+
}
86+
}
5887
};
5988

6089
// Models RootElement : DescriptorTable | DescriptorTableClause

0 commit comments

Comments
 (0)