Skip to content

Commit 352ea14

Browse files
Fix the collection of entry point interfaces (#1334)
This is a patch to expand the collection of entry point interfaces. In SPIR-V 1.4 and later OpEntryPoint must list all global variables in the interface. Also fix quoted string output in SPIRV text format. Co-authored-by: Alexey Sotkin <alexey.sotkin@intel.com>
1 parent c3c3c68 commit 352ea14

17 files changed

+138
-91
lines changed

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -606,9 +606,7 @@ SPIRVFunction *LLVMToSPIRVBase::transFunctionDecl(Function *F) {
606606
BF->setFunctionControlMask(transFunctionControlMask(F));
607607
if (F->hasName())
608608
BM->setName(BF, F->getName().str());
609-
if (isKernel(F))
610-
BM->addEntryPoint(ExecutionModelKernel, BF->getId());
611-
else if (F->getLinkage() != GlobalValue::InternalLinkage)
609+
if (!isKernel(F) && F->getLinkage() != GlobalValue::InternalLinkage)
612610
BF->setLinkageType(transLinkageType(F));
613611

614612
// Translate OpenCL/SYCL buffer_location metadata if it's attached to the
@@ -3570,12 +3568,15 @@ bool LLVMToSPIRVBase::isAnyFunctionReachableFromFunction(
35703568
return false;
35713569
}
35723570

3573-
void LLVMToSPIRVBase::collectInputOutputVariables(SPIRVFunction *SF,
3574-
Function *F) {
3571+
std::vector<SPIRVId>
3572+
LLVMToSPIRVBase::collectEntryPointInterfaces(SPIRVFunction *SF, Function *F) {
3573+
std::vector<SPIRVId> Interface;
35753574
for (auto &GV : M->globals()) {
35763575
const auto AS = GV.getAddressSpace();
3577-
if (AS != SPIRAS_Input && AS != SPIRAS_Output)
3578-
continue;
3576+
SPIRVModule *BM = SF->getModule();
3577+
if (!BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_4))
3578+
if (AS != SPIRAS_Input && AS != SPIRAS_Output)
3579+
continue;
35793580

35803581
std::unordered_set<const Function *> Funcs;
35813582

@@ -3587,9 +3588,15 @@ void LLVMToSPIRVBase::collectInputOutputVariables(SPIRVFunction *SF,
35873588
}
35883589

35893590
if (isAnyFunctionReachableFromFunction(F, Funcs)) {
3590-
SF->addVariable(ValueMap[&GV]);
3591+
SPIRVWord ModuleVersion = static_cast<SPIRVWord>(BM->getSPIRVVersion());
3592+
if (AS != SPIRAS_Input && AS != SPIRAS_Output &&
3593+
ModuleVersion < static_cast<SPIRVWord>(VersionNumber::SPIRV_1_4))
3594+
BM->setMinSPIRVVersion(
3595+
static_cast<SPIRVWord>(VersionNumber::SPIRV_1_4));
3596+
Interface.push_back(ValueMap[&GV]->getId());
35913597
}
35923598
}
3599+
return Interface;
35933600
}
35943601

35953602
void LLVMToSPIRVBase::mutateFuncArgType(
@@ -3692,10 +3699,10 @@ void LLVMToSPIRVBase::transFunction(Function *I) {
36923699
joinFPContract(I, FPContract::ENABLED);
36933700
fpContractUpdateRecursive(I, getFPContract(I));
36943701

3695-
bool IsKernelEntryPoint = isKernel(I);
3696-
3697-
if (IsKernelEntryPoint) {
3698-
collectInputOutputVariables(BF, I);
3702+
if (isKernel(I)) {
3703+
auto Interface = collectEntryPointInterfaces(BF, I);
3704+
BM->addEntryPoint(ExecutionModelKernel, BF->getId(), I->getName().str(),
3705+
Interface);
36993706
}
37003707
}
37013708

lib/SPIRV/SPIRVWriter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ class LLVMToSPIRVBase {
216216
bool isAnyFunctionReachableFromFunction(
217217
const Function *FS,
218218
const std::unordered_set<const Function *> Funcs) const;
219-
void collectInputOutputVariables(SPIRVFunction *SF, Function *F);
219+
std::vector<SPIRVId> collectEntryPointInterfaces(SPIRVFunction *BF,
220+
Function *F);
220221
};
221222

222223
class LLVMToSPIRVPass : public PassInfoMixin<LLVMToSPIRVPass>,

lib/SPIRV/libSPIRV/SPIRVDecorate.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,6 @@ class SPIRVDecorateLinkageAttr : public SPIRVDecorate {
241241
#ifdef _SPIRV_SUPPORT_TEXT_FMT
242242
if (SPIRVUseTextFormat) {
243243
Encoder << getString(Literals.cbegin(), Literals.cend() - 1);
244-
Encoder.OS << " ";
245244
Encoder << (SPIRVLinkageTypeKind)Literals.back();
246245
} else
247246
#endif

lib/SPIRV/libSPIRV/SPIRVEntry.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -541,9 +541,11 @@ void SPIRVEntryPoint::encode(spv_ostream &O) const {
541541
}
542542

543543
void SPIRVEntryPoint::decode(std::istream &I) {
544-
getDecoder(I) >> ExecModel >> Target >> Name >> Variables;
544+
getDecoder(I) >> ExecModel >> Target >> Name;
545+
Variables.resize(WordCount - FixedWC - getSizeInWords(Name) + 1);
546+
getDecoder(I) >> Variables;
545547
Module->setName(getOrCreateTarget(), Name);
546-
Module->addEntryPoint(ExecModel, Target);
548+
Module->addEntryPoint(ExecModel, Target, Name, Variables);
547549
}
548550

549551
void SPIRVExecutionMode::encode(spv_ostream &O) const {

lib/SPIRV/libSPIRV/SPIRVEntry.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,7 @@ template <Op OC> class SPIRVAnnotation : public SPIRVAnnotationGeneric {
524524

525525
class SPIRVEntryPoint : public SPIRVAnnotation<OpEntryPoint> {
526526
public:
527+
static const SPIRVWord FixedWC = 4;
527528
SPIRVEntryPoint(SPIRVModule *TheModule, SPIRVExecutionModelKind,
528529
SPIRVId TheId, const std::string &TheName,
529530
std::vector<SPIRVId> Variables);

lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -128,20 +128,6 @@ class SPIRVModuleImpl : public SPIRVModule {
128128
getValueTypes(const std::vector<SPIRVId> &) const override;
129129
SPIRVMemoryModelKind getMemoryModel() const override { return MemoryModel; }
130130
SPIRVConstant *getLiteralAsConstant(unsigned Literal) override;
131-
unsigned getNumEntryPoints(SPIRVExecutionModelKind EM) const override {
132-
auto Loc = EntryPointVec.find(EM);
133-
if (Loc == EntryPointVec.end())
134-
return 0;
135-
return Loc->second.size();
136-
}
137-
SPIRVFunction *getEntryPoint(SPIRVExecutionModelKind EM,
138-
unsigned I) const override {
139-
auto Loc = EntryPointVec.find(EM);
140-
if (Loc == EntryPointVec.end())
141-
return nullptr;
142-
assert(I < Loc->second.size());
143-
return get<SPIRVFunction>(Loc->second[I]);
144-
}
145131
unsigned getNumFunctions() const override { return FuncVec.size(); }
146132
unsigned getNumVariables() const override { return VariableVec.size(); }
147133
SourceLanguage getSourceLanguage(SPIRVWord *Ver = nullptr) const override {
@@ -215,8 +201,9 @@ class SPIRVModuleImpl : public SPIRVModule {
215201
SPIRVGroupMemberDecorate *
216202
addGroupMemberDecorate(SPIRVDecorationGroup *Group,
217203
const std::vector<SPIRVEntry *> &Targets) override;
218-
void addEntryPoint(SPIRVExecutionModelKind ExecModel,
219-
SPIRVId EntryPoint) override;
204+
void addEntryPoint(SPIRVExecutionModelKind ExecModel, SPIRVId EntryPoint,
205+
const std::string &Name,
206+
const std::vector<SPIRVId> &Variables) override;
220207
SPIRVForward *addForward(SPIRVType *Ty) override;
221208
SPIRVForward *addForward(SPIRVId, SPIRVType *Ty) override;
222209
SPIRVFunction *addFunction(SPIRVFunction *) override;
@@ -495,11 +482,11 @@ class SPIRVModuleImpl : public SPIRVModule {
495482
typedef std::vector<SPIRVGroupDecorateGeneric *> SPIRVGroupDecVec;
496483
typedef std::vector<SPIRVAsmTargetINTEL *> SPIRVAsmTargetVector;
497484
typedef std::vector<SPIRVAsmINTEL *> SPIRVAsmVector;
485+
typedef std::vector<SPIRVEntryPoint *> SPIRVEntryPointVec;
498486
typedef std::map<SPIRVId, SPIRVExtInstSetKind> SPIRVIdToInstructionSetMap;
499487
std::map<SPIRVExtInstSetKind, SPIRVId> ExtInstSetIds;
500488
typedef std::map<SPIRVId, SPIRVExtInstSetKind> SPIRVIdToBuiltinSetMap;
501489
typedef std::map<SPIRVExecutionModelKind, SPIRVIdSet> SPIRVExecModelIdSetMap;
502-
typedef std::map<SPIRVExecutionModelKind, SPIRVIdVec> SPIRVExecModelIdVecMap;
503490
typedef std::unordered_map<std::string, SPIRVString *> SPIRVStringMap;
504491
typedef std::map<SPIRVTypeStruct *, std::vector<std::pair<unsigned, SPIRVId>>>
505492
SPIRVUnknownStructFieldMap;
@@ -525,7 +512,7 @@ class SPIRVModuleImpl : public SPIRVModule {
525512
SPIRVAsmTargetVector AsmTargetVec;
526513
SPIRVAsmVector AsmVec;
527514
SPIRVExecModelIdSetMap EntryPointSet;
528-
SPIRVExecModelIdVecMap EntryPointVec;
515+
SPIRVEntryPointVec EntryPointVec;
529516
SPIRVStringMap StrMap;
530517
SPIRVCapMap CapMap;
531518
SPIRVUnknownStructFieldMap UnknownStructFieldMap;
@@ -1000,11 +987,14 @@ SPIRVModuleImpl::addDecorate(SPIRVDecorateGeneric *Dec) {
1000987
}
1001988

1002989
void SPIRVModuleImpl::addEntryPoint(SPIRVExecutionModelKind ExecModel,
1003-
SPIRVId EntryPoint) {
990+
SPIRVId EntryPoint, const std::string &Name,
991+
const std::vector<SPIRVId> &Variables) {
1004992
assert(isValid(ExecModel) && "Invalid execution model");
1005993
assert(EntryPoint != SPIRVID_INVALID && "Invalid entry point");
994+
auto *EP =
995+
add(new SPIRVEntryPoint(this, ExecModel, EntryPoint, Name, Variables));
996+
EntryPointVec.push_back(EP);
1006997
EntryPointSet[ExecModel].insert(EntryPoint);
1007-
EntryPointVec[ExecModel].push_back(EntryPoint);
1008998
addCapabilities(SPIRV::getCapability(ExecModel));
1009999
}
10101000

@@ -1833,14 +1823,10 @@ spv_ostream &operator<<(spv_ostream &O, SPIRVModule &M) {
18331823

18341824
O << SPIRVMemoryModel(&M);
18351825

1836-
for (auto &I : MI.EntryPointVec)
1837-
for (auto &II : I.second)
1838-
O << SPIRVEntryPoint(&M, I.first, II, M.get<SPIRVFunction>(II)->getName(),
1839-
M.get<SPIRVFunction>(II)->getVariables());
1826+
O << MI.EntryPointVec;
18401827

18411828
for (auto &I : MI.EntryPointVec)
1842-
for (auto &II : I.second)
1843-
MI.get<SPIRVFunction>(II)->encodeExecutionModes(O);
1829+
MI.get<SPIRVFunction>(I->getTargetId())->encodeExecutionModes(O);
18441830

18451831
O << MI.StringVec;
18461832

lib/SPIRV/libSPIRV/SPIRVModule.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,14 +133,11 @@ class SPIRVModule {
133133
virtual const SPIRVCapMap &getCapability() const = 0;
134134
virtual bool hasCapability(SPIRVCapabilityKind) const = 0;
135135
virtual SPIRVExtInstSetKind getBuiltinSet(SPIRVId) const = 0;
136-
virtual SPIRVFunction *getEntryPoint(SPIRVExecutionModelKind,
137-
unsigned) const = 0;
138136
virtual std::set<std::string> &getExtension() = 0;
139137
virtual SPIRVFunction *getFunction(unsigned) const = 0;
140138
virtual SPIRVVariable *getVariable(unsigned) const = 0;
141139
virtual SPIRVMemoryModelKind getMemoryModel() const = 0;
142140
virtual unsigned getNumFunctions() const = 0;
143-
virtual unsigned getNumEntryPoints(SPIRVExecutionModelKind) const = 0;
144141
virtual unsigned getNumVariables() const = 0;
145142
virtual SourceLanguage getSourceLanguage(SPIRVWord *) const = 0;
146143
virtual std::set<std::string> &getSourceExtension() = 0;
@@ -213,7 +210,9 @@ class SPIRVModule {
213210
const std::vector<SPIRVEntry *> &Targets) = 0;
214211
virtual SPIRVGroupDecorateGeneric *
215212
addGroupDecorateGeneric(SPIRVGroupDecorateGeneric *GDec) = 0;
216-
virtual void addEntryPoint(SPIRVExecutionModelKind, SPIRVId) = 0;
213+
virtual void addEntryPoint(SPIRVExecutionModelKind, SPIRVId,
214+
const std::string &,
215+
const std::vector<SPIRVId> &) = 0;
217216
virtual SPIRVForward *addForward(SPIRVType *Ty) = 0;
218217
virtual SPIRVForward *addForward(SPIRVId, SPIRVType *Ty) = 0;
219218
virtual SPIRVFunction *addFunction(SPIRVFunction *) = 0;

lib/SPIRV/libSPIRV/SPIRVStream.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ const SPIRVEncoder &operator<<(const SPIRVEncoder &O, const std::string &Str) {
169169
#ifdef _SPIRV_SUPPORT_TEXT_FMT
170170
if (SPIRVUseTextFormat) {
171171
writeQuotedString(O.OS, Str);
172+
O.OS << " ";
172173
return O;
173174
}
174175
#endif

test/ExecutionMode.ll

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
; RUN: llvm-as < %s | llvm-spirv -spirv-text -o %t
22
; RUN: FileCheck < %t %s
33

4-
; check for magic number followed by version 1.1
5-
; CHECK: 119734787 65792
6-
74
; CHECK-DAG: TypeVoid [[VOID:[0-9]+]]
85

96
; CHECK-DAG: EntryPoint 6 [[WORKER:[0-9]+]] "worker"

test/copy_object.spt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
2 Capability Int64
66
2 Capability Int8
77
3 MemoryModel 2 2
8-
8 EntryPoint 6 1 "copy_object"
8+
6 EntryPoint 6 1 "copy_object"
99
3 Source 3 102000
1010
3 Name 2 "in"
1111
4 Decorate 3 BuiltIn 28

0 commit comments

Comments
 (0)