Skip to content

Commit 1f575b1

Browse files
committed
[HLSL][SPIRV] Add vk::constant_id attribute.
The vk::constant_id attribute is used to indicate that a global const variable represents a specialization constant in SPIR-V. This PR adds this attribute to clang. The documetation for the attribute is [here](https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/SPIR-V.rst#specialization-constants). Fixes #142448
1 parent a903271 commit 1f575b1

File tree

15 files changed

+247
-2
lines changed

15 files changed

+247
-2
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4951,6 +4951,14 @@ def HLSLWaveSize: InheritableAttr {
49514951
let Documentation = [WaveSizeDocs];
49524952
}
49534953

4954+
def HLSLVkConstantId : InheritableAttr {
4955+
let Spellings = [CXX11<"vk", "constant_id">];
4956+
let Args = [IntArgument<"Id">];
4957+
let Subjects = SubjectList<[Var]>;
4958+
let LangOpts = [HLSL];
4959+
let Documentation = [VkConstantIdDocs];
4960+
}
4961+
49544962
def RandomizeLayout : InheritableAttr {
49554963
let Spellings = [GCC<"randomize_layout">];
49564964
let Subjects = SubjectList<[Record]>;

clang/include/clang/Basic/AttrDocs.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8195,6 +8195,21 @@ and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
81958195
}];
81968196
}
81978197

8198+
def VkConstantIdDocs : Documentation {
8199+
let Category = DocCatFunction;
8200+
let Content = [{
8201+
The ``vk::constant_id`` attribute specify the id for a SPIR-V specialization
8202+
constant. The attribute applies to const global scalar variables. The variable must be initialized with a C++11 constexpr.
8203+
In SPIR-V, the
8204+
variable will be replaced with an `OpSpecConstant` with the given id.
8205+
The syntax is:
8206+
8207+
.. code-block:: text
8208+
8209+
``[[vk::constant_id(<Id>)]] const T Name = <Init>``
8210+
}];
8211+
}
8212+
81988213
def RootSignatureDocs : Documentation {
81998214
let Category = DocCatFunction;
82008215
let Content = [{

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12891,6 +12891,21 @@ def err_spirv_enum_not_int : Error<
1289112891
def err_spirv_enum_not_valid : Error<
1289212892
"invalid value for %select{storage class}0 argument">;
1289312893

12894+
def err_specialization_const_lit_init
12895+
: Error<"variable with 'vk::constant_id' attribute cannot have an "
12896+
"initializer that is not a constexpr">;
12897+
def err_specialization_const_is_not_externally_visible
12898+
: Error<"variable with 'vk::constant_id' attribute must be externally "
12899+
"visible">;
12900+
def err_specialization_const_missing_initializer
12901+
: Error<
12902+
"variable with 'vk::constant_id' attribute must have an initializer">;
12903+
def err_specialization_const_missing_const
12904+
: Error<"variable with 'vk::constant_id' attribute must be const">;
12905+
def err_specialization_const_is_not_int_or_float
12906+
: Error<"variable with 'vk::constant_id' attribute must be an enum, bool, "
12907+
"integer, or floating point value">;
12908+
1289412909
// errors of expect.with.probability
1289512910
def err_probability_not_constant_float : Error<
1289612911
"probability argument to __builtin_expect_with_probability must be constant "

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ class SemaHLSL : public SemaBase {
9898
HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
9999
int Min, int Max, int Preferred,
100100
int SpelledArgsCount);
101+
HLSLVkConstantIdAttr *
102+
mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id);
101103
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
102104
llvm::Triple::EnvironmentType ShaderType);
103105
HLSLParamModifierAttr *
@@ -122,6 +124,7 @@ class SemaHLSL : public SemaBase {
122124
void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
123125
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
124126
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
127+
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
125128
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
126129
void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
127130
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);

clang/lib/AST/ExprConstant.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3570,6 +3570,11 @@ static bool evaluateVarDeclInit(EvalInfo &Info, const Expr *E,
35703570
if (E->isValueDependent())
35713571
return false;
35723572

3573+
// The initializer on a specialization constant is only its default value
3574+
// when it is not externally initialized. This value cannot be evaluated.
3575+
if (VD->hasAttr<HLSLVkConstantIdAttr>())
3576+
return false;
3577+
35733578
// Dig out the initializer, and use the declaration which it's attached to.
35743579
// FIXME: We should eventually check whether the variable has a reachable
35753580
// initializing declaration.

clang/lib/Basic/Attributes.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,8 @@ getScopeFromNormalizedScopeName(StringRef ScopeName) {
212212
.Case("hlsl", AttributeCommonInfo::Scope::HLSL)
213213
.Case("msvc", AttributeCommonInfo::Scope::MSVC)
214214
.Case("omp", AttributeCommonInfo::Scope::OMP)
215-
.Case("riscv", AttributeCommonInfo::Scope::RISCV);
215+
.Case("riscv", AttributeCommonInfo::Scope::RISCV)
216+
.Case("vk", AttributeCommonInfo::Scope::HLSL);
216217
}
217218

218219
unsigned AttributeCommonInfo::calculateAttributeSpellingListIndex() const {

clang/lib/CodeGen/CodeGenModule.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5190,6 +5190,29 @@ CodeGenModule::GetOrCreateLLVMGlobal(StringRef MangledName, llvm::Type *Ty,
51905190
if (const auto *CMA = D->getAttr<CodeModelAttr>())
51915191
GV->setCodeModel(CMA->getModel());
51925192

5193+
if (const auto *ConstIdAttr = D->getAttr<HLSLVkConstantIdAttr>()) {
5194+
const Expr *Init = D->getInit();
5195+
APValue InitValue;
5196+
bool IsConstExpr = Init->isCXX11ConstantExpr(getContext(), &InitValue);
5197+
assert(IsConstExpr &&
5198+
"HLSLVkConstantIdAttr requires a constant initializer");
5199+
llvm::SmallString<10> InitString;
5200+
switch (InitValue.getKind()) {
5201+
case APValue::ValueKind::Int:
5202+
InitValue.getInt().toString(InitString);
5203+
break;
5204+
case APValue::ValueKind::Float:
5205+
InitValue.getFloat().toString(InitString);
5206+
break;
5207+
default:
5208+
llvm_unreachable(
5209+
"HLSLVkConstantIdAttr requires an int or float initializer");
5210+
}
5211+
std::string ConstIdStr =
5212+
(llvm::Twine(ConstIdAttr->getId()) + "," + InitString).str();
5213+
GV->addAttribute("spirv-constant-id", ConstIdStr);
5214+
}
5215+
51935216
// Check if we a have a const declaration with an initializer, we may be
51945217
// able to emit it as available_externally to expose it's value to the
51955218
// optimizer.

clang/lib/Sema/SemaDecl.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2889,6 +2889,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
28892889
NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
28902890
WS->getPreferred(),
28912891
WS->getSpelledArgsCount());
2892+
else if (const auto *CI = dyn_cast<HLSLVkConstantIdAttr>(Attr))
2893+
NewAttr = S.HLSL().mergeVkConstantIdAttr(D, *CI, CI->getId());
28922894
else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
28932895
NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
28942896
else if (isa<SuppressAttr>(Attr))
@@ -13757,6 +13759,14 @@ void Sema::AddInitializerToDecl(Decl *RealDecl, Expr *Init, bool DirectInit) {
1375713759
return;
1375813760
}
1375913761

13762+
if (VDecl->hasAttr<HLSLVkConstantIdAttr>()) {
13763+
if (!Init->isCXX11ConstantExpr(Context)) {
13764+
Diag(VDecl->getLocation(), diag::err_specialization_const_lit_init);
13765+
VDecl->setInvalidDecl();
13766+
return;
13767+
}
13768+
}
13769+
1376013770
// Get the decls type and save a reference for later, since
1376113771
// CheckInitializerTypes may change it.
1376213772
QualType DclT = VDecl->getType(), SavT = DclT;
@@ -14217,6 +14227,14 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
1421714227
}
1421814228
}
1421914229

14230+
// HLSL variable with the `vk::constant_id` attribute must be initialized.
14231+
if (!Var->isInvalidDecl() && Var->hasAttr<HLSLVkConstantIdAttr>()) {
14232+
Diag(Var->getLocation(),
14233+
diag::err_specialization_const_missing_initializer);
14234+
Var->setInvalidDecl();
14235+
return;
14236+
}
14237+
1422014238
if (!Var->isInvalidDecl() && RealDecl->hasAttr<LoaderUninitializedAttr>()) {
1422114239
if (Var->getStorageClass() == SC_Extern) {
1422214240
Diag(Var->getLocation(), diag::err_loader_uninitialized_extern_decl)

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7510,6 +7510,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
75107510
case ParsedAttr::AT_HLSLWaveSize:
75117511
S.HLSL().handleWaveSizeAttr(D, AL);
75127512
break;
7513+
case ParsedAttr::AT_HLSLVkConstantId:
7514+
S.HLSL().handleVkConstantIdAttr(D, AL);
7515+
break;
75137516
case ParsedAttr::AT_HLSLSV_GroupThreadID:
75147517
S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
75157518
break;

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,7 @@ static CXXRecordDecl *createHostLayoutStruct(Sema &S,
505505
// - empty structs
506506
// - zero-sized arrays
507507
// - non-variable declarations
508+
// - SPIR-V specialization constants
508509
// The layout struct will be added to the HLSLBufferDecl declarations.
509510
void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) {
510511
ASTContext &AST = S.getASTContext();
@@ -520,7 +521,8 @@ void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) {
520521
for (Decl *D : BufDecl->buffer_decls()) {
521522
VarDecl *VD = dyn_cast<VarDecl>(D);
522523
if (!VD || VD->getStorageClass() == SC_Static ||
523-
VD->getType().getAddressSpace() == LangAS::hlsl_groupshared)
524+
VD->getType().getAddressSpace() == LangAS::hlsl_groupshared ||
525+
VD->hasAttr<HLSLVkConstantIdAttr>())
524526
continue;
525527
const Type *Ty = VD->getType()->getUnqualifiedDesugaredType();
526528
if (FieldDecl *FD =
@@ -607,6 +609,54 @@ HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
607609
return Result;
608610
}
609611

612+
HLSLVkConstantIdAttr *
613+
SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
614+
int Id) {
615+
616+
auto &TargetInfo = getASTContext().getTargetInfo();
617+
if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
618+
Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
619+
return nullptr;
620+
}
621+
622+
auto *VD = cast<VarDecl>(D);
623+
624+
if (!VD->getType()->isIntegerType() && !VD->getType()->isFloatingType()) {
625+
Diag(VD->getLocation(), diag::err_specialization_const_is_not_int_or_float);
626+
return nullptr;
627+
}
628+
629+
if (VD->getStorageClass() != StorageClass::SC_None &&
630+
VD->getStorageClass() != StorageClass::SC_Extern) {
631+
Diag(VD->getLocation(),
632+
diag::err_specialization_const_is_not_externally_visible);
633+
return nullptr;
634+
}
635+
636+
if (VD->isLocalVarDecl()) {
637+
Diag(VD->getLocation(),
638+
diag::err_specialization_const_is_not_externally_visible);
639+
return nullptr;
640+
}
641+
642+
if (!VD->getType().isConstQualified()) {
643+
Diag(VD->getLocation(), diag::err_specialization_const_missing_const);
644+
return nullptr;
645+
}
646+
647+
if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
648+
if (CI->getId() != Id) {
649+
Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
650+
Diag(AL.getLoc(), diag::note_conflicting_attribute);
651+
}
652+
return nullptr;
653+
}
654+
655+
HLSLVkConstantIdAttr *Result =
656+
::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
657+
return Result;
658+
}
659+
610660
HLSLShaderAttr *
611661
SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
612662
llvm::Triple::EnvironmentType ShaderType) {
@@ -1117,6 +1167,15 @@ void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
11171167
D->addAttr(NewAttr);
11181168
}
11191169

1170+
void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
1171+
uint32_t Id;
1172+
if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
1173+
return;
1174+
HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
1175+
if (NewAttr)
1176+
D->addAttr(NewAttr);
1177+
}
1178+
11201179
bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
11211180
const auto *VT = T->getAs<VectorType>();
11221181

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple spirv-unknown-vulkan-compute -x hlsl -ast-dump -o - %s | FileCheck %s
2+
3+
// CHECK: VarDecl {{.*}} specConst 'const hlsl_constant int' cinit
4+
// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 12
5+
// CHECK-NEXT: HLSLVkConstantIdAttr {{.*}} 10
6+
[[vk::constant_id(10)]]
7+
const int specConst = 12;
8+
9+
// CHECK: CXXRecordDecl {{.*}} implicit struct __cblayout_$Globals definition
10+
// CHECK-NOT: FieldDecl {{.*}} specConst 'int'
11+
12+
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
2+
// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
3+
// RUN: -o - | FileCheck %s
4+
5+
6+
// CHECK-DAG: @_ZL3sc0 = external addrspace(12) constant i32, align 4 [[A0:#[0-9]+]]
7+
// CHECK-DAG: attributes [[A0]] = { "spirv-constant-id"="0,1" }
8+
[[vk::constant_id(0)]]
9+
const bool sc0 = true;
10+
11+
// CHECK-DAG: @_ZL3sc1 = external addrspace(12) constant i32, align 4 [[A1:#[0-9]+]]
12+
// CHECK-DAG: attributes [[A1]] = { "spirv-constant-id"="1,10" }
13+
[[vk::constant_id(1)]]
14+
const int sc1 = 10;
15+
16+
// CHECK-DAG: @_ZL3sc2 = external addrspace(12) constant i32, align 4 [[A2:#[0-9]+]]
17+
// CHECK-DAG: attributes [[A2]] = { "spirv-constant-id"="2,-20" }
18+
[[vk::constant_id(2)]]
19+
const int sc2 = 10-30;
20+
21+
// CHECK-DAG: @_ZL3sc3 = external addrspace(12) constant float, align 4 [[A3:#[0-9]+]]
22+
// CHECK-DAG: attributes [[A3]] = { "spirv-constant-id"="3,0.25" }
23+
[[vk::constant_id(3)]]
24+
const float sc3 = 0.5*0.5;
25+
26+
// CHECK-DAG: @_ZL3sc4 = external addrspace(12) constant i32, align 4 [[A4:#[0-9]+]]
27+
// CHECK-DAG: attributes [[A4]] = { "spirv-constant-id"="4,2" }
28+
enum E {
29+
A,
30+
B,
31+
C
32+
};
33+
34+
[[vk::constant_id(4)]]
35+
const E sc4 = E::C;
36+
37+
[numthreads(1,1,1)]
38+
void main() {
39+
bool b = sc0;
40+
int i = sc1;
41+
int j = sc2;
42+
float f = sc3;
43+
E e = sc4;
44+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple spirv-pc-vulkan1.3-compute -verify %s
2+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.8-compute -verify %s
3+
4+
#ifndef __spirv__
5+
// expected-warning@+2{{'constant_id' attribute ignored}}
6+
#endif
7+
[[vk::constant_id(0)]]
8+
const bool sc0 = true;
9+
10+
#ifdef __spirv__
11+
[[vk::constant_id(1)]]
12+
// expected-error@+1{{variable with 'vk::constant_id' attribute cannot have an initializer that is not a constexpr}}
13+
const bool sc1 = sc0; // error
14+
15+
[[vk::constant_id(2)]]
16+
// expected-error@+1{{variable with 'vk::constant_id' attribute must be externally visible}}
17+
static const bool sc2 = false; // error
18+
19+
[[vk::constant_id(3)]]
20+
// expected-error@+1{{variable with 'vk::constant_id' attribute must have an initializer}}
21+
const bool sc3; // error
22+
23+
[[vk::constant_id(4)]]
24+
// expected-error@+1{{variable with 'vk::constant_id' attribute must be const}}
25+
bool sc4 = false; // error
26+
27+
[[vk::constant_id(5)]]
28+
// expected-error@+1{{variable with 'vk::constant_id' attribute must be an enum, bool, integer, or floating point value}}
29+
const int2 sc5 = {0,0}; // error
30+
#endif
31+
32+
[numthreads(1,1,1)]
33+
void main() {
34+
#ifdef __spirv__
35+
[[vk::constant_id(6)]]
36+
// expected-error@+1{{variable with 'vk::constant_id' attribute must be externally visible}}
37+
const bool sc6 = false; // error
38+
#endif
39+
}

0 commit comments

Comments
 (0)