Skip to content
5 changes: 4 additions & 1 deletion clang/include/clang/Basic/DiagnosticParseKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -1830,8 +1830,11 @@ def err_hlsl_virtual_function
def err_hlsl_virtual_inheritance
: Error<"virtual inheritance is unsupported in HLSL">;

// HLSL Root Siganture diagnostic messages
// HLSL Root Signature Parser Diagnostics
def err_hlsl_unexpected_end_of_params
: Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">;
def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
def err_hlsl_number_literal_overflow : Error<"integer literal is too large to be represented as a 32-bit %select{signed |}0 integer type">;

} // end of Parser diagnostics
42 changes: 32 additions & 10 deletions clang/include/clang/Parse/ParseHLSLRootSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,31 @@ class RootSignatureParser {
private:
DiagnosticsEngine &getDiags() { return PP.getDiagnostics(); }

// All private Parse.* methods follow a similar pattern:
// All private parse.* methods follow a similar pattern:
// - Each method will start with an assert to denote what the CurToken is
// expected to be and will parse from that token forward
//
// - Therefore, it is the callers responsibility to ensure that you are
// at the correct CurToken. This should be done with the pattern of:
//
// if (TryConsumeExpectedToken(RootSignatureToken::Kind))
// if (Parse.*())
// return true;
// if (tryConsumeExpectedToken(RootSignatureToken::Kind)) {
// auto ParsedObject = parse.*();
// if (!ParsedObject.has_value())
// return std::nullopt;
// ...
// }
//
// or,
//
// if (ConsumeExpectedToken(RootSignatureToken::Kind, ...))
// return true;
// if (Parse.*())
// return true;
// if (consumeExpectedToken(RootSignatureToken::Kind, ...))
// return std::nullopt;
// auto ParsedObject = parse.*();
// if (!ParsedObject.has_value())
// return std::nullopt;
// ...
//
// - All methods return true if a parsing error is encountered. It is the
// callers responsibility to propogate this error up, or deal with it
// - All methods return std::nullopt if a parsing error is encountered. It
// is the callers responsibility to propogate this error up, or deal with it
// otherwise
//
// - An error will be raised if the proceeding tokens are not what is
Expand All @@ -69,6 +74,23 @@ class RootSignatureParser {
bool parseDescriptorTable();
bool parseDescriptorTableClause();

/// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
/// order and only exactly once. `ParsedClauseParams` denotes the current
/// state of parsed params
struct ParsedClauseParams {
std::optional<llvm::hlsl::rootsig::Register> Register;
std::optional<uint32_t> Space;
};
std::optional<ParsedClauseParams>
parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);

std::optional<uint32_t> parseUIntParam();
std::optional<llvm::hlsl::rootsig::Register> parseRegister();

/// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
/// 32-bit integer
std::optional<uint32_t> handleUIntLiteral();

/// Invoke the Lexer to consume a token and update CurToken with the result
void consumeNextToken() { CurToken = Lexer.ConsumeToken(); }

Expand Down
164 changes: 149 additions & 15 deletions clang/lib/Parse/ParseHLSLRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include "clang/Parse/ParseHLSLRootSignature.h"

#include "clang/Lex/LiteralSupport.h"

#include "llvm/Support/raw_ostream.h"

using namespace llvm::hlsl::rootsig;
Expand Down Expand Up @@ -41,12 +43,11 @@ bool RootSignatureParser::parse() {
break;
}

if (!tryConsumeExpectedToken(TokenKind::end_of_stream)) {
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_unexpected_end_of_params)
<< /*expected=*/TokenKind::end_of_stream
<< /*param of=*/TokenKind::kw_RootSignature;
if (consumeExpectedToken(TokenKind::end_of_stream,
diag::err_hlsl_unexpected_end_of_params,
/*param of=*/TokenKind::kw_RootSignature))
return true;
}

return false;
}

Expand All @@ -72,12 +73,10 @@ bool RootSignatureParser::parseDescriptorTable() {
break;
}

if (!tryConsumeExpectedToken(TokenKind::pu_r_paren)) {
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_unexpected_end_of_params)
<< /*expected=*/TokenKind::pu_r_paren
<< /*param of=*/TokenKind::kw_DescriptorTable;
if (consumeExpectedToken(TokenKind::pu_r_paren,
diag::err_hlsl_unexpected_end_of_params,
/*param of=*/TokenKind::kw_DescriptorTable))
return true;
}

Elements.push_back(Table);
return false;
Expand All @@ -90,36 +89,170 @@ bool RootSignatureParser::parseDescriptorTableClause() {
CurToken.TokKind == TokenKind::kw_Sampler) &&
"Expects to only be invoked starting at given keyword");

TokenKind ParamKind = CurToken.TokKind;

if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
CurToken.TokKind))
return true;

DescriptorTableClause Clause;
switch (CurToken.TokKind) {
TokenKind ExpectedReg;
switch (ParamKind) {
default:
llvm_unreachable("Switch for consumed token was not provided");
case TokenKind::kw_CBV:
Clause.Type = ClauseType::CBuffer;
ExpectedReg = TokenKind::bReg;
break;
case TokenKind::kw_SRV:
Clause.Type = ClauseType::SRV;
ExpectedReg = TokenKind::tReg;
break;
case TokenKind::kw_UAV:
Clause.Type = ClauseType::UAV;
ExpectedReg = TokenKind::uReg;
break;
case TokenKind::kw_Sampler:
Clause.Type = ClauseType::Sampler;
ExpectedReg = TokenKind::sReg;
break;
}

if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
CurToken.TokKind))
auto Params = parseDescriptorTableClauseParams(ExpectedReg);
if (!Params.has_value())
return true;

if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_expected_after,
CurToken.TokKind))
// Check mandatory parameters were provided
if (!Params->Register.has_value()) {
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
<< ExpectedReg;
return true;
}

Clause.Register = Params->Register.value();

// Fill in optional values
if (Params->Space.has_value())
Clause.Space = Params->Space.value();

if (consumeExpectedToken(TokenKind::pu_r_paren,
diag::err_hlsl_unexpected_end_of_params,
/*param of=*/ParamKind))
return true;

Elements.push_back(Clause);
return false;
}

std::optional<RootSignatureParser::ParsedClauseParams>
RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
assert(CurToken.TokKind == TokenKind::pu_l_paren &&
"Expects to only be invoked starting at given token");

// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
// order and only exactly once. Parse through as many arguments as possible
// reporting an error if a duplicate is seen.
ParsedClauseParams Params;
do {
// ( `b` | `t` | `u` | `s`) POS_INT
if (tryConsumeExpectedToken(RegType)) {
if (Params.Register.has_value()) {
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
<< CurToken.TokKind;
return std::nullopt;
}
auto Reg = parseRegister();
if (!Reg.has_value())
return std::nullopt;
Params.Register = Reg;
}

// `space` `=` POS_INT
if (tryConsumeExpectedToken(TokenKind::kw_space)) {
if (Params.Space.has_value()) {
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
<< CurToken.TokKind;
return std::nullopt;
}

if (consumeExpectedToken(TokenKind::pu_equal))
return std::nullopt;

auto Space = parseUIntParam();
if (!Space.has_value())
return std::nullopt;
Params.Space = Space;
}
} while (tryConsumeExpectedToken(TokenKind::pu_comma));

return Params;
}

std::optional<uint32_t> RootSignatureParser::parseUIntParam() {
assert(CurToken.TokKind == TokenKind::pu_equal &&
"Expects to only be invoked starting at given keyword");
tryConsumeExpectedToken(TokenKind::pu_plus);
if (consumeExpectedToken(TokenKind::int_literal, diag::err_expected_after,
CurToken.TokKind))
return std::nullopt;
return handleUIntLiteral();
}

std::optional<Register> RootSignatureParser::parseRegister() {
assert((CurToken.TokKind == TokenKind::bReg ||
CurToken.TokKind == TokenKind::tReg ||
CurToken.TokKind == TokenKind::uReg ||
CurToken.TokKind == TokenKind::sReg) &&
"Expects to only be invoked starting at given keyword");

Register Register;
switch (CurToken.TokKind) {
default:
llvm_unreachable("Switch for consumed token was not provided");
case TokenKind::bReg:
Register.ViewType = RegisterType::BReg;
break;
case TokenKind::tReg:
Register.ViewType = RegisterType::TReg;
break;
case TokenKind::uReg:
Register.ViewType = RegisterType::UReg;
break;
case TokenKind::sReg:
Register.ViewType = RegisterType::SReg;
break;
}

auto Number = handleUIntLiteral();
if (!Number.has_value())
return std::nullopt; // propogate NumericLiteralParser error

Register.Number = *Number;
return Register;
}

std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
// Parse the numeric value and do semantic checks on its specification
clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
PP.getSourceManager(), PP.getLangOpts(),
PP.getTargetInfo(), PP.getDiagnostics());
if (Literal.hadError)
return true; // Error has already been reported so just return

assert(Literal.isIntegerLiteral() && "IsNumberChar will only support digits");

llvm::APSInt Val = llvm::APSInt(32, false);
if (Literal.GetIntegerValue(Val)) {
// Report that the value has overflowed
PP.getDiagnostics().Report(CurToken.TokLoc,
diag::err_hlsl_number_literal_overflow)
<< 0 << CurToken.NumSpelling;
return std::nullopt;
}

return Val.getExtValue();
}

bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
return peekExpectedToken(ArrayRef{Expected});
}
Expand All @@ -141,6 +274,7 @@ bool RootSignatureParser::consumeExpectedToken(TokenKind Expected,
case diag::err_expected:
DB << Expected;
break;
case diag::err_hlsl_unexpected_end_of_params:
case diag::err_expected_either:
case diag::err_expected_after:
DB << Expected << Context;
Expand Down
Loading
Loading