Skip to content

Commit 10b0aab

Browse files
authored
Add infrastructure for translating ExecutionModeId (#2297)
This functionality was added in SPIR-V 1.2 and allows using an <id> to set the execution modes SubgroupsPerWorkgroupId, LocalSizeId, and LocalSizeHintI, and others.
1 parent 4dfbc85 commit 10b0aab

File tree

4 files changed

+68
-38
lines changed

4 files changed

+68
-38
lines changed

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5657,19 +5657,20 @@ bool LLVMToSPIRVBase::transExecutionMode() {
56575657
auto AddSingleArgExecutionMode = [&](ExecutionMode EMode) {
56585658
uint32_t Arg = ~0u;
56595659
N.get(Arg);
5660-
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(BF, EMode, Arg)));
5660+
BF->addExecutionMode(
5661+
BM->add(new SPIRVExecutionMode(OpExecutionMode, BF, EMode, Arg)));
56615662
};
56625663

56635664
switch (EMode) {
56645665
case spv::ExecutionModeContractionOff:
5665-
BF->addExecutionMode(BM->add(
5666-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
5666+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
5667+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
56675668
break;
56685669
case spv::ExecutionModeInitializer:
56695670
case spv::ExecutionModeFinalizer:
56705671
if (BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_1)) {
5671-
BF->addExecutionMode(BM->add(
5672-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
5672+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
5673+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
56735674
} else {
56745675
getErrorLog().checkError(false, SPIRVEC_Requires1_1,
56755676
"Initializer/Finalizer Execution Mode");
@@ -5681,15 +5682,16 @@ bool LLVMToSPIRVBase::transExecutionMode() {
56815682
unsigned X = 0, Y = 0, Z = 0;
56825683
N.get(X).get(Y).get(Z);
56835684
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
5684-
BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
5685+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
56855686
} break;
56865687
case spv::ExecutionModeMaxWorkgroupSizeINTEL: {
56875688
if (BM->isAllowedToUseExtension(
56885689
ExtensionID::SPV_INTEL_kernel_attributes)) {
56895690
unsigned X = 0, Y = 0, Z = 0;
56905691
N.get(X).get(Y).get(Z);
56915692
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
5692-
BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
5693+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X, Y,
5694+
Z)));
56935695
BM->addExtension(ExtensionID::SPV_INTEL_kernel_attributes);
56945696
BM->addCapability(CapabilityKernelAttributesINTEL);
56955697
}
@@ -5698,8 +5700,8 @@ bool LLVMToSPIRVBase::transExecutionMode() {
56985700
if (!BM->isAllowedToUseExtension(
56995701
ExtensionID::SPV_INTEL_kernel_attributes))
57005702
break;
5701-
BF->addExecutionMode(BM->add(
5702-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
5703+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
5704+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
57035705
BM->addExtension(ExtensionID::SPV_INTEL_kernel_attributes);
57045706
BM->addCapability(CapabilityKernelAttributesINTEL);
57055707
} break;
@@ -5743,7 +5745,8 @@ bool LLVMToSPIRVBase::transExecutionMode() {
57435745
unsigned NBarrierCnt = 0;
57445746
N.get(NBarrierCnt);
57455747
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
5746-
BF, static_cast<ExecutionMode>(EMode), NBarrierCnt)));
5748+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode),
5749+
NBarrierCnt)));
57475750
BM->addExtension(ExtensionID::SPV_INTEL_vector_compute);
57485751
BM->addCapability(CapabilityVectorComputeINTEL);
57495752
} break;
@@ -5773,8 +5776,8 @@ bool LLVMToSPIRVBase::transExecutionMode() {
57735776
} break;
57745777
case spv::internal::ExecutionModeFastCompositeKernelINTEL: {
57755778
if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_fast_composite))
5776-
BF->addExecutionMode(BM->add(
5777-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
5779+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
5780+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
57785781
} break;
57795782
default:
57805783
llvm_unreachable("invalid execution mode");
@@ -5819,8 +5822,8 @@ void LLVMToSPIRVBase::transFPContract() {
58195822
}
58205823

58215824
if (DisableContraction) {
5822-
BF->addExecutionMode(BF->getModule()->add(
5823-
new SPIRVExecutionMode(BF, spv::ExecutionModeContractionOff)));
5825+
BF->addExecutionMode(BF->getModule()->add(new SPIRVExecutionMode(
5826+
OpExecutionMode, BF, spv::ExecutionModeContractionOff)));
58245827
}
58255828
}
58265829
}

lib/SPIRV/libSPIRV/SPIRVEntry.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ SPIRVEntryPoint::SPIRVEntryPoint(SPIRVModule *TheModule,
612612
SPIRVExecutionModelKind TheExecModel,
613613
SPIRVId TheId, const std::string &TheName,
614614
std::vector<SPIRVId> Variables)
615-
: SPIRVAnnotation(TheModule->get<SPIRVFunction>(TheId),
615+
: SPIRVAnnotation(OpEntryPoint, TheModule->get<SPIRVFunction>(TheId),
616616
getSizeInWords(TheName) + Variables.size() + 3),
617617
ExecModel(TheExecModel), Name(TheName), Variables(Variables) {}
618618

@@ -681,7 +681,8 @@ SPIRVForward *SPIRVAnnotationGeneric::getOrCreateTarget() const {
681681
}
682682

683683
SPIRVName::SPIRVName(const SPIRVEntry *TheTarget, const std::string &TheStr)
684-
: SPIRVAnnotation(TheTarget, getSizeInWords(TheStr) + 2), Str(TheStr) {}
684+
: SPIRVAnnotation(OpName, TheTarget, getSizeInWords(TheStr) + 2),
685+
Str(TheStr) {}
685686

686687
void SPIRVName::encode(spv_ostream &O) const { getEncoder(O) << Target << Str; }
687688

lib/SPIRV/libSPIRV/SPIRVEntry.h

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -521,23 +521,24 @@ class SPIRVAnnotationGeneric : public SPIRVEntryNoIdGeneric {
521521
SPIRVId Target;
522522
};
523523

524-
template <Op OC> class SPIRVAnnotation : public SPIRVAnnotationGeneric {
524+
class SPIRVAnnotation : public SPIRVAnnotationGeneric {
525525
public:
526526
// Complete constructor
527-
SPIRVAnnotation(const SPIRVEntry *TheTarget, unsigned TheWordCount)
527+
SPIRVAnnotation(Op OC, const SPIRVEntry *TheTarget, unsigned TheWordCount)
528528
: SPIRVAnnotationGeneric(TheTarget->getModule(), TheWordCount, OC,
529529
TheTarget->getId()) {}
530-
// Incomplete constructor
531-
SPIRVAnnotation() : SPIRVAnnotationGeneric(OC) {}
530+
// Incomplete constructors
531+
SPIRVAnnotation(Op OC) : SPIRVAnnotationGeneric(OC) {}
532+
SPIRVAnnotation() : SPIRVAnnotationGeneric(OpNop) {}
532533
};
533534

534-
class SPIRVEntryPoint : public SPIRVAnnotation<OpEntryPoint> {
535+
class SPIRVEntryPoint : public SPIRVAnnotation {
535536
public:
536537
static const SPIRVWord FixedWC = 4;
537538
SPIRVEntryPoint(SPIRVModule *TheModule, SPIRVExecutionModelKind,
538539
SPIRVId TheId, const std::string &TheName,
539540
std::vector<SPIRVId> Variables);
540-
SPIRVEntryPoint() {}
541+
SPIRVEntryPoint() : SPIRVAnnotation(OpEntryPoint) {}
541542

542543
_SPIRV_DCL_ENCDEC
543544
protected:
@@ -548,12 +549,12 @@ class SPIRVEntryPoint : public SPIRVAnnotation<OpEntryPoint> {
548549
std::vector<SPIRVId> Variables;
549550
};
550551

551-
class SPIRVName : public SPIRVAnnotation<OpName> {
552+
class SPIRVName : public SPIRVAnnotation {
552553
public:
553554
// Complete constructor
554555
SPIRVName(const SPIRVEntry *TheTarget, const std::string &TheStr);
555556
// Incomplete constructor
556-
SPIRVName() {}
557+
SPIRVName() : SPIRVAnnotation(OpName) {}
557558

558559
protected:
559560
_SPIRV_DCL_ENCDEC
@@ -562,18 +563,18 @@ class SPIRVName : public SPIRVAnnotation<OpName> {
562563
std::string Str;
563564
};
564565

565-
class SPIRVMemberName : public SPIRVAnnotation<OpName> {
566+
class SPIRVMemberName : public SPIRVAnnotation {
566567
public:
567568
static const SPIRVWord FixedWC = 3;
568569
// Complete constructor
569570
SPIRVMemberName(const SPIRVEntry *TheTarget, SPIRVWord TheMemberNumber,
570571
const std::string &TheStr)
571-
: SPIRVAnnotation(TheTarget, FixedWC + getSizeInWords(TheStr)),
572+
: SPIRVAnnotation(OpName, TheTarget, FixedWC + getSizeInWords(TheStr)),
572573
MemberNumber(TheMemberNumber), Str(TheStr) {
573574
validate();
574575
}
575576
// Incomplete constructor
576-
SPIRVMemberName() : MemberNumber(SPIRVWORD_MAX) {}
577+
SPIRVMemberName() : SPIRVAnnotation(OpName), MemberNumber(SPIRVWORD_MAX) {}
577578

578579
protected:
579580
_SPIRV_DCL_ENCDEC
@@ -649,31 +650,33 @@ class SPIRVLine : public SPIRVEntry {
649650
SPIRVWord Column;
650651
};
651652

652-
class SPIRVExecutionMode : public SPIRVAnnotation<OpExecutionMode> {
653+
class SPIRVExecutionMode : public SPIRVAnnotation {
653654
public:
654655
// Complete constructor for LocalSize, LocalSizeHint
655-
SPIRVExecutionMode(SPIRVEntry *TheTarget, SPIRVExecutionModeKind TheExecMode,
656-
SPIRVWord X, SPIRVWord Y, SPIRVWord Z)
657-
: SPIRVAnnotation(TheTarget, 6), ExecMode(TheExecMode) {
656+
SPIRVExecutionMode(Op OC, SPIRVEntry *TheTarget,
657+
SPIRVExecutionModeKind TheExecMode, SPIRVWord X,
658+
SPIRVWord Y, SPIRVWord Z)
659+
: SPIRVAnnotation(OC, TheTarget, 6), ExecMode(TheExecMode) {
658660
WordLiterals.push_back(X);
659661
WordLiterals.push_back(Y);
660662
WordLiterals.push_back(Z);
661663
updateModuleVersion();
662664
}
663665
// Complete constructor for VecTypeHint, SubgroupSize, SubgroupsPerWorkgroup
664-
SPIRVExecutionMode(SPIRVEntry *TheTarget, SPIRVExecutionModeKind TheExecMode,
665-
SPIRVWord Code)
666-
: SPIRVAnnotation(TheTarget, 4), ExecMode(TheExecMode) {
666+
SPIRVExecutionMode(Op OC, SPIRVEntry *TheTarget,
667+
SPIRVExecutionModeKind TheExecMode, SPIRVWord Code)
668+
: SPIRVAnnotation(OC, TheTarget, 4), ExecMode(TheExecMode) {
667669
WordLiterals.push_back(Code);
668-
updateModuleVersion();
669670
}
670671
// Complete constructor for ContractionOff
671-
SPIRVExecutionMode(SPIRVEntry *TheTarget, SPIRVExecutionModeKind TheExecMode)
672-
: SPIRVAnnotation(TheTarget, 3), ExecMode(TheExecMode) {
672+
SPIRVExecutionMode(Op OC, SPIRVEntry *TheTarget,
673+
SPIRVExecutionModeKind TheExecMode)
674+
: SPIRVAnnotation(OC, TheTarget, 3), ExecMode(TheExecMode) {
673675
updateModuleVersion();
674676
}
675677
// Incomplete constructor
676-
SPIRVExecutionMode() : ExecMode(ExecutionModeInvocations) {}
678+
SPIRVExecutionMode()
679+
: SPIRVAnnotation(OpExecutionMode), ExecMode(ExecutionModeInvocations) {}
677680
SPIRVExecutionModeKind getExecutionMode() const { return ExecMode; }
678681
const std::vector<SPIRVWord> &getLiterals() const { return WordLiterals; }
679682
SPIRVCapVec getRequiredCapability() const override {
@@ -699,6 +702,28 @@ class SPIRVExecutionMode : public SPIRVAnnotation<OpExecutionMode> {
699702
std::vector<SPIRVWord> WordLiterals;
700703
};
701704

705+
class SPIRVExecutionModeId : public SPIRVExecutionMode {
706+
public:
707+
// Complete constructor for LocalSizeId, LocalSizeHintId
708+
SPIRVExecutionModeId(SPIRVEntry *TheTarget,
709+
SPIRVExecutionModeKind TheExecMode, SPIRVWord X,
710+
SPIRVWord Y, SPIRVWord Z)
711+
: SPIRVExecutionMode(OpExecutionModeId, TheTarget, TheExecMode, X, Y, Z) {
712+
updateModuleVersion();
713+
}
714+
// Complete constructor for SubgroupsPerWorkgroupId
715+
SPIRVExecutionModeId(SPIRVEntry *TheTarget,
716+
SPIRVExecutionModeKind TheExecMode, SPIRVWord Code)
717+
: SPIRVExecutionMode(OpExecutionModeId, TheTarget, TheExecMode, Code) {
718+
updateModuleVersion();
719+
}
720+
// Incomplete constructor
721+
SPIRVExecutionModeId() : SPIRVExecutionMode() {}
722+
SPIRVWord getRequiredSPIRVVersion() const override {
723+
return static_cast<SPIRVWord>(VersionNumber::SPIRV_1_2);
724+
}
725+
};
726+
702727
class SPIRVComponentExecutionModes {
703728
typedef std::multimap<SPIRVExecutionModeKind, SPIRVExecutionMode *>
704729
SPIRVExecutionModeMap;

lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ _SPIRV_OP(TypePipeStorage, 322)
295295
_SPIRV_OP(ConstantPipeStorage, 323)
296296
_SPIRV_OP(CreatePipeFromPipeStorage, 324)
297297
_SPIRV_OP(ModuleProcessed, 330)
298+
_SPIRV_OP(ExecutionModeId, 331)
298299
_SPIRV_OP(DecorateId, 332)
299300
_SPIRV_OP(GroupNonUniformElect, 333)
300301
_SPIRV_OP(GroupNonUniformAll, 334)

0 commit comments

Comments
 (0)