Skip to content

Commit 5d87ba1

Browse files
authored
[HLSL] Use llvm::Triple::EnvironmentType instead of HLSLShaderAttr::ShaderType (#93847)
`HLSLShaderAttr::ShaderType` enum is a subset of `llvm::Triple::EnvironmentType`. We can use `llvm::Triple::EnvironmentType` directly and avoid converting one enum to another.
1 parent 221336c commit 5d87ba1

File tree

4 files changed

+69
-75
lines changed

4 files changed

+69
-75
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4470,37 +4470,20 @@ def HLSLShader : InheritableAttr {
44704470
let Subjects = SubjectList<[HLSLEntry]>;
44714471
let LangOpts = [HLSL];
44724472
let Args = [
4473-
EnumArgument<"Type", "ShaderType", /*is_string=*/true,
4473+
EnumArgument<"Type", "llvm::Triple::EnvironmentType", /*is_string=*/true,
44744474
["pixel", "vertex", "geometry", "hull", "domain", "compute",
44754475
"raygeneration", "intersection", "anyhit", "closesthit",
44764476
"miss", "callable", "mesh", "amplification"],
44774477
["Pixel", "Vertex", "Geometry", "Hull", "Domain", "Compute",
44784478
"RayGeneration", "Intersection", "AnyHit", "ClosestHit",
4479-
"Miss", "Callable", "Mesh", "Amplification"]>
4479+
"Miss", "Callable", "Mesh", "Amplification"],
4480+
/*opt=*/0, /*fake=*/0, /*isExternalType=*/1>
44804481
];
44814482
let Documentation = [HLSLSV_ShaderTypeAttrDocs];
44824483
let AdditionalMembers =
44834484
[{
4484-
static const unsigned ShaderTypeMaxValue = (unsigned)HLSLShaderAttr::Amplification;
4485-
4486-
static llvm::Triple::EnvironmentType getTypeAsEnvironment(HLSLShaderAttr::ShaderType ShaderType) {
4487-
switch (ShaderType) {
4488-
case HLSLShaderAttr::Pixel: return llvm::Triple::Pixel;
4489-
case HLSLShaderAttr::Vertex: return llvm::Triple::Vertex;
4490-
case HLSLShaderAttr::Geometry: return llvm::Triple::Geometry;
4491-
case HLSLShaderAttr::Hull: return llvm::Triple::Hull;
4492-
case HLSLShaderAttr::Domain: return llvm::Triple::Domain;
4493-
case HLSLShaderAttr::Compute: return llvm::Triple::Compute;
4494-
case HLSLShaderAttr::RayGeneration: return llvm::Triple::RayGeneration;
4495-
case HLSLShaderAttr::Intersection: return llvm::Triple::Intersection;
4496-
case HLSLShaderAttr::AnyHit: return llvm::Triple::AnyHit;
4497-
case HLSLShaderAttr::ClosestHit: return llvm::Triple::ClosestHit;
4498-
case HLSLShaderAttr::Miss: return llvm::Triple::Miss;
4499-
case HLSLShaderAttr::Callable: return llvm::Triple::Callable;
4500-
case HLSLShaderAttr::Mesh: return llvm::Triple::Mesh;
4501-
case HLSLShaderAttr::Amplification: return llvm::Triple::Amplification;
4502-
}
4503-
llvm_unreachable("unknown enumeration value");
4485+
static bool isValidShaderType(llvm::Triple::EnvironmentType ShaderType) {
4486+
return ShaderType >= llvm::Triple::Pixel && ShaderType <= llvm::Triple::Amplification;
45044487
}
45054488
}];
45064489
}

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class SemaHLSL : public SemaBase {
3939
const AttributeCommonInfo &AL, int X,
4040
int Y, int Z);
4141
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
42-
HLSLShaderAttr::ShaderType ShaderType);
42+
llvm::Triple::EnvironmentType ShaderType);
4343
HLSLParamModifierAttr *
4444
mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
4545
HLSLParamModifierAttr::Spelling Spelling);
@@ -48,8 +48,8 @@ class SemaHLSL : public SemaBase {
4848
void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
4949
const HLSLAnnotationAttr *AnnotationAttr);
5050
void DiagnoseAttrStageMismatch(
51-
const Attr *A, HLSLShaderAttr::ShaderType Stage,
52-
std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
51+
const Attr *A, llvm::Triple::EnvironmentType Stage,
52+
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
5353
void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
5454

5555
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
313313
assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
314314
const StringRef ShaderAttrKindStr = "hlsl.shader";
315315
Fn->addFnAttr(ShaderAttrKindStr,
316-
ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
316+
llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));
317317
if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
318318
const StringRef NumThreadsKindStr = "hlsl.numthreads";
319319
std::string NumThreadsStr =

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 60 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D,
146146

147147
HLSLShaderAttr *
148148
SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
149-
HLSLShaderAttr::ShaderType ShaderType) {
149+
llvm::Triple::EnvironmentType ShaderType) {
150150
if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
151151
if (NT->getType() != ShaderType) {
152152
Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
@@ -184,25 +184,24 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
184184
if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
185185
return;
186186

187-
StringRef Env = TargetInfo.getTriple().getEnvironmentName();
188-
HLSLShaderAttr::ShaderType ShaderType;
189-
if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
187+
llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment();
188+
if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) {
190189
if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
191190
// The entry point is already annotated - check that it matches the
192191
// triple.
193-
if (Shader->getType() != ShaderType) {
192+
if (Shader->getType() != Env) {
194193
Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
195194
<< Shader;
196195
FD->setInvalidDecl();
197196
}
198197
} else {
199198
// Implicitly add the shader attribute if the entry function isn't
200199
// explicitly annotated.
201-
FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), ShaderType,
200+
FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env,
202201
FD->getBeginLoc()));
203202
}
204203
} else {
205-
switch (TargetInfo.getTriple().getEnvironment()) {
204+
switch (Env) {
206205
case llvm::Triple::UnknownEnvironment:
207206
case llvm::Triple::Library:
208207
break;
@@ -215,38 +214,40 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) {
215214
void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) {
216215
const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
217216
assert(ShaderAttr && "Entry point has no shader attribute");
218-
HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
217+
llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
219218

220219
switch (ST) {
221-
case HLSLShaderAttr::Pixel:
222-
case HLSLShaderAttr::Vertex:
223-
case HLSLShaderAttr::Geometry:
224-
case HLSLShaderAttr::Hull:
225-
case HLSLShaderAttr::Domain:
226-
case HLSLShaderAttr::RayGeneration:
227-
case HLSLShaderAttr::Intersection:
228-
case HLSLShaderAttr::AnyHit:
229-
case HLSLShaderAttr::ClosestHit:
230-
case HLSLShaderAttr::Miss:
231-
case HLSLShaderAttr::Callable:
220+
case llvm::Triple::Pixel:
221+
case llvm::Triple::Vertex:
222+
case llvm::Triple::Geometry:
223+
case llvm::Triple::Hull:
224+
case llvm::Triple::Domain:
225+
case llvm::Triple::RayGeneration:
226+
case llvm::Triple::Intersection:
227+
case llvm::Triple::AnyHit:
228+
case llvm::Triple::ClosestHit:
229+
case llvm::Triple::Miss:
230+
case llvm::Triple::Callable:
232231
if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
233232
DiagnoseAttrStageMismatch(NT, ST,
234-
{HLSLShaderAttr::Compute,
235-
HLSLShaderAttr::Amplification,
236-
HLSLShaderAttr::Mesh});
233+
{llvm::Triple::Compute,
234+
llvm::Triple::Amplification,
235+
llvm::Triple::Mesh});
237236
FD->setInvalidDecl();
238237
}
239238
break;
240239

241-
case HLSLShaderAttr::Compute:
242-
case HLSLShaderAttr::Amplification:
243-
case HLSLShaderAttr::Mesh:
240+
case llvm::Triple::Compute:
241+
case llvm::Triple::Amplification:
242+
case llvm::Triple::Mesh:
244243
if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
245244
Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
246-
<< HLSLShaderAttr::ConvertShaderTypeToStr(ST);
245+
<< llvm::Triple::getEnvironmentTypeName(ST);
247246
FD->setInvalidDecl();
248247
}
249248
break;
249+
default:
250+
llvm_unreachable("Unhandled environment in triple");
250251
}
251252

252253
for (ParmVarDecl *Param : FD->parameters()) {
@@ -268,31 +269,31 @@ void SemaHLSL::CheckSemanticAnnotation(
268269
const HLSLAnnotationAttr *AnnotationAttr) {
269270
auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
270271
assert(ShaderAttr && "Entry point has no shader attribute");
271-
HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
272+
llvm::Triple::EnvironmentType ST = ShaderAttr->getType();
272273

273274
switch (AnnotationAttr->getKind()) {
274275
case attr::HLSLSV_DispatchThreadID:
275276
case attr::HLSLSV_GroupIndex:
276-
if (ST == HLSLShaderAttr::Compute)
277+
if (ST == llvm::Triple::Compute)
277278
return;
278-
DiagnoseAttrStageMismatch(AnnotationAttr, ST, {HLSLShaderAttr::Compute});
279+
DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
279280
break;
280281
default:
281282
llvm_unreachable("Unknown HLSLAnnotationAttr");
282283
}
283284
}
284285

285286
void SemaHLSL::DiagnoseAttrStageMismatch(
286-
const Attr *A, HLSLShaderAttr::ShaderType Stage,
287-
std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) {
287+
const Attr *A, llvm::Triple::EnvironmentType Stage,
288+
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) {
288289
SmallVector<StringRef, 8> StageStrings;
289290
llvm::transform(AllowedStages, std::back_inserter(StageStrings),
290-
[](HLSLShaderAttr::ShaderType ST) {
291+
[](llvm::Triple::EnvironmentType ST) {
291292
return StringRef(
292-
HLSLShaderAttr::ConvertShaderTypeToStr(ST));
293+
HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST));
293294
});
294295
Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
295-
<< A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage)
296+
<< A << llvm::Triple::getEnvironmentTypeName(Stage)
296297
<< (AllowedStages.size() != 1) << join(StageStrings, ", ");
297298
}
298299

@@ -430,8 +431,8 @@ void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
430431
if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
431432
return;
432433

433-
HLSLShaderAttr::ShaderType ShaderType;
434-
if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType)) {
434+
llvm::Triple::EnvironmentType ShaderType;
435+
if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) {
435436
Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
436437
<< AL << Str << ArgLoc;
437438
return;
@@ -549,16 +550,22 @@ class DiagnoseHLSLAvailability
549550
//
550551
// Maps FunctionDecl to an unsigned number that represents the set of shader
551552
// environments the function has been scanned for.
552-
// Since HLSLShaderAttr::ShaderType enum is generated from Attr.td and is
553-
// defined without any assigned values, it is guaranteed to be numbered
554-
// sequentially from 0 up and we can use it to 'index' individual bits
555-
// in the set.
553+
// The llvm::Triple::EnvironmentType enum values for shader stages guaranteed
554+
// to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification
555+
// (verified by static_asserts in Triple.cpp), we can use it to index
556+
// individual bits in the set, as long as we shift the values to start with 0
557+
// by subtracting the value of llvm::Triple::Pixel first.
558+
//
556559
// The N'th bit in the set will be set if the function has been scanned
557-
// in shader environment whose ShaderType integer value equals N.
560+
// in shader environment whose llvm::Triple::EnvironmentType integer value
561+
// equals (llvm::Triple::Pixel + N).
562+
//
558563
// For example, if a function has been scanned in compute and pixel stage
559-
// environment, the value will be 0x21 (100001 binary) because
560-
// (int)HLSLShaderAttr::ShaderType::Pixel == 1 and
561-
// (int)HLSLShaderAttr::ShaderType::Compute == 5.
564+
// environment, the value will be 0x21 (100001 binary) because:
565+
//
566+
// (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0
567+
// (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5
568+
//
562569
// A FunctionDecl is mapped to 0 (or not included in the map) if it has not
563570
// been scanned in any environment.
564571
llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls;
@@ -574,12 +581,16 @@ class DiagnoseHLSLAvailability
574581
bool ReportOnlyShaderStageIssues;
575582

576583
// Helper methods for dealing with current stage context / environment
577-
void SetShaderStageContext(HLSLShaderAttr::ShaderType ShaderType) {
584+
void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) {
578585
static_assert(sizeof(unsigned) >= 4);
579-
assert((unsigned)ShaderType < 31); // 31 is reserved for "unknown"
580-
581-
CurrentShaderEnvironment = HLSLShaderAttr::getTypeAsEnvironment(ShaderType);
582-
CurrentShaderStageBit = (1 << ShaderType);
586+
assert(HLSLShaderAttr::isValidShaderType(ShaderType));
587+
assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 &&
588+
"ShaderType is too big for this bitmap"); // 31 is reserved for
589+
// "unknown"
590+
591+
unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel;
592+
CurrentShaderEnvironment = ShaderType;
593+
CurrentShaderStageBit = (1 << bitmapIndex);
583594
}
584595

585596
void SetUnknownShaderStageContext() {

0 commit comments

Comments
 (0)