Skip to content

[SPIR-V] Apply changes related to SPIR-V 1.4 #4886

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 58 additions & 53 deletions llvm-spirv/lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -792,9 +792,7 @@ SPIRVFunction *LLVMToSPIRVBase::transFunctionDecl(Function *F) {

BM->setName(BF, F->getName().str());
}
if (isKernel(F))
BM->addEntryPoint(ExecutionModelKernel, BF->getId());
else if (F->getLinkage() != GlobalValue::InternalLinkage)
if (!isKernel(F) && F->getLinkage() != GlobalValue::InternalLinkage)
BF->setLinkageType(transLinkageType(F));

// Translate OpenCL/SYCL buffer_location metadata if it's attached to the
Expand Down Expand Up @@ -1394,9 +1392,12 @@ LLVMToSPIRVBase::getLoopControl(const BranchInst *Branch,
// PartialCount must not be used with the DontUnroll bit
else if (S == "llvm.loop.unroll.count" &&
!(LoopControl & LoopControlDontUnrollMask)) {
size_t I = getMDOperandAsInt(Node, 1);
ParametersToSort.emplace_back(spv::LoopControlPartialCountMask, I);
LoopControl |= spv::LoopControlPartialCountMask;
if (BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_4)) {
BM->setMinSPIRVVersion(VersionNumber::SPIRV_1_4);
size_t I = getMDOperandAsInt(Node, 1);
ParametersToSort.emplace_back(spv::LoopControlPartialCountMask, I);
LoopControl |= spv::LoopControlPartialCountMask;
}
} else if (S == "llvm.loop.ivdep.enable")
LoopControl |= spv::LoopControlDependencyInfiniteMask;
else if (S == "llvm.loop.ivdep.safelen") {
Expand Down Expand Up @@ -2446,10 +2447,10 @@ bool LLVMToSPIRVBase::transDecoration(Value *V, SPIRVValue *BV) {

if (auto BVO = dyn_cast_or_null<OverflowingBinaryOperator>(V)) {
if (BVO->hasNoSignedWrap()) {
BV->setNoSignedWrap(true);
BV->setNoIntegerDecorationWrap<DecorationNoSignedWrap>(true);
}
if (BVO->hasNoUnsignedWrap()) {
BV->setNoUnsignedWrap(true);
BV->setNoIntegerDecorationWrap<DecorationNoUnsignedWrap>(true);
}
}

Expand Down Expand Up @@ -4200,12 +4201,15 @@ bool LLVMToSPIRVBase::isAnyFunctionReachableFromFunction(
return false;
}

void LLVMToSPIRVBase::collectInputOutputVariables(SPIRVFunction *SF,
Function *F) {
std::vector<SPIRVId>
LLVMToSPIRVBase::collectEntryPointInterfaces(SPIRVFunction *SF, Function *F) {
std::vector<SPIRVId> Interface;
for (auto &GV : M->globals()) {
const auto AS = GV.getAddressSpace();
if (AS != SPIRAS_Input && AS != SPIRAS_Output)
continue;
SPIRVModule *BM = SF->getModule();
if (!BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_4))
if (AS != SPIRAS_Input && AS != SPIRAS_Output)
continue;

std::unordered_set<const Function *> Funcs;

Expand All @@ -4217,9 +4221,14 @@ void LLVMToSPIRVBase::collectInputOutputVariables(SPIRVFunction *SF,
}

if (isAnyFunctionReachableFromFunction(F, Funcs)) {
SF->addVariable(ValueMap[&GV]);
SPIRVWord ModuleVersion = static_cast<SPIRVWord>(BM->getSPIRVVersion());
if (AS != SPIRAS_Input && AS != SPIRAS_Output &&
ModuleVersion < static_cast<SPIRVWord>(VersionNumber::SPIRV_1_4))
BM->setMinSPIRVVersion(VersionNumber::SPIRV_1_4);
Interface.push_back(ValueMap[&GV]->getId());
}
}
return Interface;
}

void LLVMToSPIRVBase::mutateFuncArgType(
Expand Down Expand Up @@ -4322,10 +4331,10 @@ void LLVMToSPIRVBase::transFunction(Function *I) {
joinFPContract(I, FPContract::ENABLED);
fpContractUpdateRecursive(I, getFPContract(I));

bool IsKernelEntryPoint = isKernel(I);

if (IsKernelEntryPoint) {
collectInputOutputVariables(BF, I);
if (isKernel(I)) {
auto Interface = collectEntryPointInterfaces(BF, I);
BM->addEntryPoint(ExecutionModelKernel, BF->getId(), I->getName().str(),
Interface);
}
}

Expand Down Expand Up @@ -4493,6 +4502,12 @@ bool LLVMToSPIRVBase::transExecutionMode() {
if (!BF)
return false;

auto AddSingleArgExecutionMode = [&](ExecutionMode EMode) {
uint32_t Arg;
N.get(Arg);
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(BF, EMode, Arg)));
};

switch (EMode) {
case spv::ExecutionModeContractionOff:
BF->addExecutionMode(BM->add(
Expand Down Expand Up @@ -4528,56 +4543,49 @@ bool LLVMToSPIRVBase::transExecutionMode() {
}
} break;
case spv::ExecutionModeNoGlobalOffsetINTEL: {
if (BM->isAllowedToUseExtension(
ExtensionID::SPV_INTEL_kernel_attributes)) {
BF->addExecutionMode(BM->add(
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
BM->addExtension(ExtensionID::SPV_INTEL_kernel_attributes);
BM->addCapability(CapabilityKernelAttributesINTEL);
}
if (!BM->isAllowedToUseExtension(
ExtensionID::SPV_INTEL_kernel_attributes))
break;
BF->addExecutionMode(BM->add(
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
BM->addExtension(ExtensionID::SPV_INTEL_kernel_attributes);
BM->addCapability(CapabilityKernelAttributesINTEL);
} break;
case spv::ExecutionModeVecTypeHint:
case spv::ExecutionModeSubgroupSize:
case spv::ExecutionModeSubgroupsPerWorkgroup: {
unsigned X;
N.get(X);
BF->addExecutionMode(BM->add(
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode), X)));
} break;
case spv::ExecutionModeSubgroupsPerWorkgroup:
AddSingleArgExecutionMode(static_cast<ExecutionMode>(EMode));
break;
case spv::ExecutionModeNumSIMDWorkitemsINTEL:
case spv::ExecutionModeSchedulerTargetFmaxMhzINTEL:
case spv::ExecutionModeMaxWorkDimINTEL:
case spv::internal::ExecutionModeStreamingInterfaceINTEL: {
if (BM->isAllowedToUseExtension(
ExtensionID::SPV_INTEL_kernel_attributes)) {
unsigned X;
N.get(X);
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
BF, static_cast<ExecutionMode>(EMode), X)));
BM->addExtension(ExtensionID::SPV_INTEL_kernel_attributes);
BM->addCapability(CapabilityFPGAKernelAttributesINTEL);
}
if (!BM->isAllowedToUseExtension(
ExtensionID::SPV_INTEL_kernel_attributes))
break;
AddSingleArgExecutionMode(static_cast<ExecutionMode>(EMode));
BM->addExtension(ExtensionID::SPV_INTEL_kernel_attributes);
BM->addCapability(CapabilityFPGAKernelAttributesINTEL);
} break;
case spv::ExecutionModeSharedLocalMemorySizeINTEL: {
if (!BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_vector_compute))
break;
unsigned SLMSize;
N.get(SLMSize);
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
BF, static_cast<ExecutionMode>(EMode), SLMSize)));
AddSingleArgExecutionMode(static_cast<ExecutionMode>(EMode));
} break;

case spv::ExecutionModeDenormPreserve:
case spv::ExecutionModeDenormFlushToZero:
case spv::ExecutionModeSignedZeroInfNanPreserve:
case spv::ExecutionModeRoundingModeRTE:
case spv::ExecutionModeRoundingModeRTZ: {
if (!BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_float_controls))
break;
unsigned TargetWidth;
N.get(TargetWidth);
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
BF, static_cast<ExecutionMode>(EMode), TargetWidth)));
if (BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_4)) {
BM->setMinSPIRVVersion(VersionNumber::SPIRV_1_4);
AddSingleArgExecutionMode(static_cast<ExecutionMode>(EMode));
} else if (BM->isAllowedToUseExtension(
ExtensionID::SPV_KHR_float_controls)) {
BM->addExtension(ExtensionID::SPV_KHR_float_controls);
AddSingleArgExecutionMode(static_cast<ExecutionMode>(EMode));
}
} break;
case spv::ExecutionModeRoundingModeRTPINTEL:
case spv::ExecutionModeRoundingModeRTNINTEL:
Expand All @@ -4586,10 +4594,7 @@ bool LLVMToSPIRVBase::transExecutionMode() {
if (!BM->isAllowedToUseExtension(
ExtensionID::SPV_INTEL_float_controls2))
break;
unsigned TargetWidth;
N.get(TargetWidth);
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
BF, static_cast<ExecutionMode>(EMode), TargetWidth)));
AddSingleArgExecutionMode(static_cast<ExecutionMode>(EMode));
} break;
case spv::internal::ExecutionModeFastCompositeKernelINTEL: {
if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_fast_composite))
Expand Down
3 changes: 2 additions & 1 deletion llvm-spirv/lib/SPIRV/SPIRVWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ class LLVMToSPIRVBase {
bool isAnyFunctionReachableFromFunction(
const Function *FS,
const std::unordered_set<const Function *> Funcs) const;
void collectInputOutputVariables(SPIRVFunction *SF, Function *F);
std::vector<SPIRVId> collectEntryPointInterfaces(SPIRVFunction *BF,
Function *F);
};

class LLVMToSPIRVPass : public PassInfoMixin<LLVMToSPIRVPass> {
Expand Down
6 changes: 2 additions & 4 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVDecorate.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class SPIRVDecorateGeneric : public SPIRVAnnotationGeneric {

case DecorationMaxByteOffset:
return static_cast<SPIRVWord>(VersionNumber::SPIRV_1_1);
case DecorationUserSemantic:
return static_cast<SPIRVWord>(VersionNumber::SPIRV_1_4);

default:
return static_cast<SPIRVWord>(VersionNumber::SPIRV_1_0);
Expand Down Expand Up @@ -127,9 +129,6 @@ class SPIRVDecorate : public SPIRVDecorateGeneric {

llvm::Optional<ExtensionID> getRequiredExtension() const override {
switch (static_cast<size_t>(Dec)) {
case DecorationNoSignedWrap:
case DecorationNoUnsignedWrap:
return ExtensionID::SPV_KHR_no_integer_wrap_decoration;
case DecorationRegisterINTEL:
case DecorationMemoryINTEL:
case DecorationNumbanksINTEL:
Expand Down Expand Up @@ -246,7 +245,6 @@ class SPIRVDecorateLinkageAttr : public SPIRVDecorate {
#ifdef _SPIRV_SUPPORT_TEXT_FMT
if (SPIRVUseTextFormat) {
Encoder << getString(Literals.cbegin(), Literals.cend() - 1);
Encoder.OS << " ";
Encoder << (SPIRVLinkageTypeKind)Literals.back();
} else
#endif
Expand Down
6 changes: 4 additions & 2 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,9 +550,11 @@ void SPIRVEntryPoint::encode(spv_ostream &O) const {
}

void SPIRVEntryPoint::decode(std::istream &I) {
getDecoder(I) >> ExecModel >> Target >> Name >> Variables;
getDecoder(I) >> ExecModel >> Target >> Name;
Variables.resize(WordCount - FixedWC - getSizeInWords(Name) + 1);
getDecoder(I) >> Variables;
Module->setName(getOrCreateTarget(), Name);
Module->addEntryPoint(ExecModel, Target);
Module->addEntryPoint(ExecModel, Target, Name, Variables);
}

void SPIRVExecutionMode::encode(spv_ostream &O) const {
Expand Down
7 changes: 1 addition & 6 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ template <Op OC> class SPIRVAnnotation : public SPIRVAnnotationGeneric {

class SPIRVEntryPoint : public SPIRVAnnotation<OpEntryPoint> {
public:
static const SPIRVWord FixedWC = 4;
SPIRVEntryPoint(SPIRVModule *TheModule, SPIRVExecutionModelKind,
SPIRVId TheId, const std::string &TheName,
std::vector<SPIRVId> Variables);
Expand Down Expand Up @@ -843,12 +844,6 @@ class SPIRVCapability : public SPIRVEntryNoId<OpCapability> {

llvm::Optional<ExtensionID> getRequiredExtension() const override {
switch (static_cast<unsigned>(Kind)) {
case CapabilityDenormPreserve:
case CapabilityDenormFlushToZero:
case CapabilitySignedZeroInfNanPreserve:
case CapabilityRoundingModeRTE:
case CapabilityRoundingModeRTZ:
return ExtensionID::SPV_KHR_float_controls;
case CapabilityRoundToInfinityINTEL:
case CapabilityFloatingPointModeINTEL:
case CapabilityFunctionFloatControlINTEL:
Expand Down
38 changes: 12 additions & 26 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,20 +128,6 @@ class SPIRVModuleImpl : public SPIRVModule {
getValueTypes(const std::vector<SPIRVId> &) const override;
SPIRVMemoryModelKind getMemoryModel() const override { return MemoryModel; }
SPIRVConstant *getLiteralAsConstant(unsigned Literal) override;
unsigned getNumEntryPoints(SPIRVExecutionModelKind EM) const override {
auto Loc = EntryPointVec.find(EM);
if (Loc == EntryPointVec.end())
return 0;
return Loc->second.size();
}
SPIRVFunction *getEntryPoint(SPIRVExecutionModelKind EM,
unsigned I) const override {
auto Loc = EntryPointVec.find(EM);
if (Loc == EntryPointVec.end())
return nullptr;
assert(I < Loc->second.size());
return get<SPIRVFunction>(Loc->second[I]);
}
unsigned getNumFunctions() const override { return FuncVec.size(); }
unsigned getNumVariables() const override { return VariableVec.size(); }
SourceLanguage getSourceLanguage(SPIRVWord *Ver = nullptr) const override {
Expand Down Expand Up @@ -215,8 +201,9 @@ class SPIRVModuleImpl : public SPIRVModule {
SPIRVGroupMemberDecorate *
addGroupMemberDecorate(SPIRVDecorationGroup *Group,
const std::vector<SPIRVEntry *> &Targets) override;
void addEntryPoint(SPIRVExecutionModelKind ExecModel,
SPIRVId EntryPoint) override;
void addEntryPoint(SPIRVExecutionModelKind ExecModel, SPIRVId EntryPoint,
const std::string &Name,
const std::vector<SPIRVId> &Variables) override;
SPIRVForward *addForward(SPIRVType *Ty) override;
SPIRVForward *addForward(SPIRVId, SPIRVType *Ty) override;
SPIRVFunction *addFunction(SPIRVFunction *) override;
Expand Down Expand Up @@ -494,11 +481,11 @@ class SPIRVModuleImpl : public SPIRVModule {
typedef std::vector<SPIRVGroupDecorateGeneric *> SPIRVGroupDecVec;
typedef std::vector<SPIRVAsmTargetINTEL *> SPIRVAsmTargetVector;
typedef std::vector<SPIRVAsmINTEL *> SPIRVAsmVector;
typedef std::vector<SPIRVEntryPoint *> SPIRVEntryPointVec;
typedef std::map<SPIRVId, SPIRVExtInstSetKind> SPIRVIdToInstructionSetMap;
std::map<SPIRVExtInstSetKind, SPIRVId> ExtInstSetIds;
typedef std::map<SPIRVId, SPIRVExtInstSetKind> SPIRVIdToBuiltinSetMap;
typedef std::map<SPIRVExecutionModelKind, SPIRVIdSet> SPIRVExecModelIdSetMap;
typedef std::map<SPIRVExecutionModelKind, SPIRVIdVec> SPIRVExecModelIdVecMap;
typedef std::unordered_map<std::string, SPIRVString *> SPIRVStringMap;
typedef std::map<SPIRVTypeStruct *, std::vector<std::pair<unsigned, SPIRVId>>>
SPIRVUnknownStructFieldMap;
Expand All @@ -525,7 +512,7 @@ class SPIRVModuleImpl : public SPIRVModule {
SPIRVAsmTargetVector AsmTargetVec;
SPIRVAsmVector AsmVec;
SPIRVExecModelIdSetMap EntryPointSet;
SPIRVExecModelIdVecMap EntryPointVec;
SPIRVEntryPointVec EntryPointVec;
SPIRVStringMap StrMap;
SPIRVCapMap CapMap;
SPIRVUnknownStructFieldMap UnknownStructFieldMap;
Expand Down Expand Up @@ -1012,11 +999,14 @@ SPIRVModuleImpl::addDecorate(SPIRVDecorateGeneric *Dec) {
}

void SPIRVModuleImpl::addEntryPoint(SPIRVExecutionModelKind ExecModel,
SPIRVId EntryPoint) {
SPIRVId EntryPoint, const std::string &Name,
const std::vector<SPIRVId> &Variables) {
assert(isValid(ExecModel) && "Invalid execution model");
assert(EntryPoint != SPIRVID_INVALID && "Invalid entry point");
auto *EP =
add(new SPIRVEntryPoint(this, ExecModel, EntryPoint, Name, Variables));
EntryPointVec.push_back(EP);
EntryPointSet[ExecModel].insert(EntryPoint);
EntryPointVec[ExecModel].push_back(EntryPoint);
addCapabilities(SPIRV::getCapability(ExecModel));
}

Expand Down Expand Up @@ -1850,14 +1840,10 @@ spv_ostream &operator<<(spv_ostream &O, SPIRVModule &M) {

O << SPIRVMemoryModel(&M);

for (auto &I : MI.EntryPointVec)
for (auto &II : I.second)
O << SPIRVEntryPoint(&M, I.first, II, M.get<SPIRVFunction>(II)->getName(),
M.get<SPIRVFunction>(II)->getVariables());
O << MI.EntryPointVec;

for (auto &I : MI.EntryPointVec)
for (auto &II : I.second)
MI.get<SPIRVFunction>(II)->encodeExecutionModes(O);
MI.get<SPIRVFunction>(I->getTargetId())->encodeExecutionModes(O);

O << MI.StringVec;

Expand Down
7 changes: 3 additions & 4 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,11 @@ class SPIRVModule {
virtual const SPIRVCapMap &getCapability() const = 0;
virtual bool hasCapability(SPIRVCapabilityKind) const = 0;
virtual SPIRVExtInstSetKind getBuiltinSet(SPIRVId) const = 0;
virtual SPIRVFunction *getEntryPoint(SPIRVExecutionModelKind,
unsigned) const = 0;
virtual std::set<std::string> &getExtension() = 0;
virtual SPIRVFunction *getFunction(unsigned) const = 0;
virtual SPIRVVariable *getVariable(unsigned) const = 0;
virtual SPIRVMemoryModelKind getMemoryModel() const = 0;
virtual unsigned getNumFunctions() const = 0;
virtual unsigned getNumEntryPoints(SPIRVExecutionModelKind) const = 0;
virtual unsigned getNumVariables() const = 0;
virtual SourceLanguage getSourceLanguage(SPIRVWord *) const = 0;
virtual std::set<std::string> &getSourceExtension() = 0;
Expand Down Expand Up @@ -213,7 +210,9 @@ class SPIRVModule {
const std::vector<SPIRVEntry *> &Targets) = 0;
virtual SPIRVGroupDecorateGeneric *
addGroupDecorateGeneric(SPIRVGroupDecorateGeneric *GDec) = 0;
virtual void addEntryPoint(SPIRVExecutionModelKind, SPIRVId) = 0;
virtual void addEntryPoint(SPIRVExecutionModelKind, SPIRVId,
const std::string &,
const std::vector<SPIRVId> &) = 0;
virtual SPIRVForward *addForward(SPIRVType *Ty) = 0;
virtual SPIRVForward *addForward(SPIRVId, SPIRVType *Ty) = 0;
virtual SPIRVFunction *addFunction(SPIRVFunction *) = 0;
Expand Down
1 change: 1 addition & 0 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ const SPIRVEncoder &operator<<(const SPIRVEncoder &O, const std::string &Str) {
#ifdef _SPIRV_SUPPORT_TEXT_FMT
if (SPIRVUseTextFormat) {
writeQuotedString(O.OS, Str);
O.OS << " ";
return O;
}
#endif
Expand Down
Loading