Skip to content

Commit a6ca745

Browse files
authored
Fix SPV-IR for OpBuildNDRange (#1145)
OpBuildNDRange is translated to a function with the unmangled name __spirv_BuildNDRange_{1|2|3}D, with a struct return value and array arguments, since the translator only translates it properly to SPIR-V with this signature. The _{1|2|3}D postfix is required because array arguments are mangled in the same way, so without the postfix the translator would produce functions with the same name for different dimensions.
1 parent 248ca68 commit a6ca745

File tree

11 files changed

+218
-199
lines changed

11 files changed

+218
-199
lines changed

docs/SPIRVRepresentationInLLVM.rst

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,17 @@ The unmangled names of SPIR-V builtin GenericCastToPtrExplicit function follow t
150150
.. code-block:: c
151151
152152
__spirv_GenericCastToPtrExplicit_To{Global|Local|Private}
153-
154-
SPIR-V 1.1 Builtin CreatePipeFromPipeStorage Function Name
153+
154+
SPIR-V Builtin BuildNDRange Function Name
155+
-----------------------------------------
156+
157+
The unmangled names of SPIR-V builtin BuildNDRange functions follow the convention:
158+
159+
.. code-block:: c
160+
161+
__spirv_BuildNDRange_{1|2|3}D
162+
163+
SPIR-V 1.1 Builtin CreatePipeFromPipeStorage Function Name
155164
----------------------------------------
156165

157166
The unmangled names of SPIR-V builtin CreatePipeFromPipeStorage function follow the convention:

lib/SPIRV/SPIRVInternal.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,6 +1037,23 @@ bool lowerBuiltinVariableToCall(GlobalVariable *GV,
10371037
// Transform all builtin variables into calls
10381038
bool lowerBuiltinVariablesToCalls(Module *M);
10391039

1040+
/// \brief Post-process OpenCL or SPIRV builtin function returning struct type.
1041+
///
1042+
/// Some builtin functions are translated to SPIR-V instructions with
1043+
/// struct type result, e.g. NDRange creation functions. Such functions
1044+
/// need to be post-processed to return the struct through sret argument.
1045+
bool postProcessBuiltinReturningStruct(Function *F);
1046+
1047+
/// \brief Post-process OpenCL or SPIRV builtin function having array argument.
1048+
///
1049+
/// These functions are translated to functions with array type argument
1050+
/// first, then post-processed to have pointer arguments.
1051+
bool postProcessBuiltinWithArrayArguments(Function *F, StringRef DemangledName);
1052+
1053+
bool postProcessBuiltinsReturningStruct(Module *M, bool IsCpp = false);
1054+
1055+
bool postProcessBuiltinsWithArrayArguments(Module *M, bool IsCpp = false);
1056+
10401057
} // namespace SPIRV
10411058

10421059
#endif // SPIRV_SPIRVINTERNAL_H

lib/SPIRV/SPIRVReader.cpp

Lines changed: 27 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -562,41 +562,6 @@ bool SPIRVToLLVM::isSPIRVCmpInstTransToLLVMInst(SPIRVInstruction *BI) const {
562562
return isCmpOpCode(OC) && !(OC >= OpLessOrGreater && OC <= OpUnordered);
563563
}
564564

565-
// TODO: Instead of direct translation to OCL we should always produce SPIR-V
566-
// friendly IR and apply lowering later if needed
567-
bool SPIRVToLLVM::isDirectlyTranslatedToOCL(Op OpCode) const {
568-
if (isSubgroupAvcINTELInstructionOpCode(OpCode))
569-
return false;
570-
if (isIntelSubgroupOpCode(OpCode))
571-
return false;
572-
if (OpCode == OpImageSampleExplicitLod || OpCode == OpSampledImage)
573-
return false;
574-
if (OpCode == OpImageWrite || OpCode == OpImageRead ||
575-
OpCode == OpImageQueryOrder || OpCode == OpImageQueryFormat ||
576-
OpCode == OpImageQueryLevels)
577-
return false;
578-
if (OpCode == OpGenericCastToPtrExplicit)
579-
return false;
580-
if (isEventOpCode(OpCode))
581-
return false;
582-
if (OpBitFieldInsert <= OpCode && OpCode <= OpBitReverse)
583-
return false;
584-
if (OpCode == OpEnqueueMarker || OpCode == OpGetDefaultQueue)
585-
return false;
586-
if (OCLSPIRVBuiltinMap::rfind(OpCode, nullptr)) {
587-
// Not every spirv opcode which is placed in OCLSPIRVBuiltinMap is
588-
// translated directly to OCL builtin. Some of them are translated
589-
// to LLVM representation without any modifications (SPIRV format of
590-
// instruction is represented in LLVM) and then its translated to
591-
// clang-consistent format in SPIRVToOCL pass.
592-
return !(isAtomicOpCode(OpCode) || isGroupOpCode(OpCode) ||
593-
isGroupNonUniformOpcode(OpCode) || isPipeOpCode(OpCode) ||
594-
isMediaBlockINTELOpcode(OpCode) || OpCode == OpGroupAsyncCopy ||
595-
OpCode == OpGroupWaitEvents);
596-
}
597-
return false;
598-
}
599-
600565
void SPIRVToLLVM::setName(llvm::Value *V, SPIRVValue *BV) {
601566
auto Name = BV->getName();
602567
if (!Name.empty() && (!V->hasName() || Name != V->getName()))
@@ -1117,102 +1082,6 @@ Value *SPIRVToLLVM::transCmpInst(SPIRVValue *BV, BasicBlock *BB, Function *F) {
11171082
return Inst;
11181083
}
11191084

1120-
bool SPIRVToLLVM::postProcessOCL() {
1121-
StringRef DemangledName;
1122-
SPIRVWord SrcLangVer = 0;
1123-
BM->getSourceLanguage(&SrcLangVer);
1124-
bool IsCpp = SrcLangVer == kOCLVer::CL21;
1125-
for (auto I = M->begin(), E = M->end(); I != E;) {
1126-
auto F = I++;
1127-
if (F->hasName() && F->isDeclaration()) {
1128-
LLVM_DEBUG(dbgs() << "[postProcessOCL sret] " << *F << '\n');
1129-
if (F->getReturnType()->isStructTy() &&
1130-
oclIsBuiltin(F->getName(), DemangledName, IsCpp)) {
1131-
if (!postProcessOCLBuiltinReturnStruct(&(*F)))
1132-
return false;
1133-
}
1134-
}
1135-
}
1136-
for (auto I = M->begin(), E = M->end(); I != E;) {
1137-
auto F = I++;
1138-
if (F->hasName() && F->isDeclaration()) {
1139-
LLVM_DEBUG(dbgs() << "[postProcessOCL array arg] " << *F << '\n');
1140-
if (hasArrayArg(&(*F)) &&
1141-
oclIsBuiltin(F->getName(), DemangledName, IsCpp))
1142-
if (!postProcessOCLBuiltinWithArrayArguments(&(*F), DemangledName))
1143-
return false;
1144-
}
1145-
}
1146-
return true;
1147-
}
1148-
1149-
bool SPIRVToLLVM::postProcessOCLBuiltinReturnStruct(Function *F) {
1150-
std::string Name = F->getName().str();
1151-
F->setName(Name + ".old");
1152-
for (auto I = F->user_begin(), E = F->user_end(); I != E;) {
1153-
if (auto CI = dyn_cast<CallInst>(*I++)) {
1154-
auto ST = dyn_cast<StoreInst>(*(CI->user_begin()));
1155-
assert(ST);
1156-
std::vector<Type *> ArgTys;
1157-
getFunctionTypeParameterTypes(F->getFunctionType(), ArgTys);
1158-
ArgTys.insert(ArgTys.begin(),
1159-
PointerType::get(F->getReturnType(), SPIRAS_Private));
1160-
auto NewF =
1161-
getOrCreateFunction(M, Type::getVoidTy(*Context), ArgTys, Name);
1162-
NewF->setCallingConv(F->getCallingConv());
1163-
auto Args = getArguments(CI);
1164-
Args.insert(Args.begin(), ST->getPointerOperand());
1165-
auto NewCI = CallInst::Create(NewF, Args, CI->getName(), CI);
1166-
NewCI->setCallingConv(CI->getCallingConv());
1167-
ST->eraseFromParent();
1168-
CI->eraseFromParent();
1169-
}
1170-
}
1171-
F->eraseFromParent();
1172-
return true;
1173-
}
1174-
1175-
bool SPIRVToLLVM::postProcessOCLBuiltinWithArrayArguments(
1176-
Function *F, StringRef DemangledName) {
1177-
LLVM_DEBUG(dbgs() << "[postProcessOCLBuiltinWithArrayArguments] " << *F
1178-
<< '\n');
1179-
auto Attrs = F->getAttributes();
1180-
auto Name = F->getName();
1181-
mutateFunction(
1182-
F,
1183-
[=](CallInst *CI, std::vector<Value *> &Args) {
1184-
auto FBegin =
1185-
CI->getParent()->getParent()->begin()->getFirstInsertionPt();
1186-
for (auto &I : Args) {
1187-
auto T = I->getType();
1188-
if (!T->isArrayTy())
1189-
continue;
1190-
auto Alloca = new AllocaInst(T, 0, "", &(*FBegin));
1191-
new StoreInst(I, Alloca, false, CI);
1192-
auto Zero =
1193-
ConstantInt::getNullValue(Type::getInt32Ty(T->getContext()));
1194-
Value *Index[] = {Zero, Zero};
1195-
I = GetElementPtrInst::CreateInBounds(T, Alloca, Index, "", CI);
1196-
}
1197-
return Name.str();
1198-
},
1199-
nullptr, &Attrs);
1200-
return true;
1201-
}
1202-
1203-
CallInst *SPIRVToLLVM::postProcessOCLBuildNDRange(SPIRVInstruction *BI,
1204-
CallInst *CI,
1205-
const std::string &FuncName) {
1206-
assert(CI->getNumArgOperands() == 3);
1207-
auto GWS = CI->getArgOperand(0);
1208-
auto LWS = CI->getArgOperand(1);
1209-
auto GWO = CI->getArgOperand(2);
1210-
CI->setArgOperand(0, GWO);
1211-
CI->setArgOperand(1, GWS);
1212-
CI->setArgOperand(2, LWS);
1213-
return CI;
1214-
}
1215-
12161085
Type *SPIRVToLLVM::mapType(SPIRVType *BT, Type *T) {
12171086
SPIRVDBG(dbgs() << *T << '\n';)
12181087
TypeMap[BT] = T;
@@ -2569,14 +2438,17 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
25692438

25702439
default: {
25712440
auto OC = BV->getOpCode();
2572-
if (isSPIRVCmpInstTransToLLVMInst(static_cast<SPIRVInstruction *>(BV))) {
2441+
if (isSPIRVCmpInstTransToLLVMInst(static_cast<SPIRVInstruction *>(BV)))
25732442
return mapValue(BV, transCmpInst(BV, BB, F));
2574-
} else if (isDirectlyTranslatedToOCL(OC)) {
2575-
return mapValue(
2576-
BV, transOCLBuiltinFromInst(static_cast<SPIRVInstruction *>(BV), BB));
2577-
} else if (isBinaryShiftLogicalBitwiseOpCode(OC) || isLogicalOpCode(OC)) {
2443+
2444+
if (OCLSPIRVBuiltinMap::rfind(OC, nullptr))
2445+
return mapValue(BV, transSPIRVBuiltinFromInst(
2446+
static_cast<SPIRVInstruction *>(BV), BB));
2447+
2448+
if (isBinaryShiftLogicalBitwiseOpCode(OC) || isLogicalOpCode(OC))
25782449
return mapValue(BV, transShiftLogicalBitwiseInst(BV, BB, F));
2579-
} else if (isCvtOpCode(OC) && OC != OpGenericCastToPtrExplicit) {
2450+
2451+
if (isCvtOpCode(OC) && OC != OpGenericCastToPtrExplicit) {
25802452
auto BI = static_cast<SPIRVInstruction *>(BV);
25812453
Value *Inst = nullptr;
25822454
if (BI->hasFPRoundingMode() || BI->isSaturatedConversion())
@@ -2929,8 +2801,6 @@ SPIRVToLLVM::transOCLBuiltinPostproc(SPIRVInstruction *BI, CallInst *CI,
29292801
}
29302802
if (OC == OpGenericPtrMemSemantics)
29312803
return BinaryOperator::CreateShl(CI, getInt32(M, 8), "", BB);
2932-
if (OC == OpBuildNDRange)
2933-
return postProcessOCLBuildNDRange(BI, CI, DemangledName);
29342804
if (SPIRVEnableStepExpansion &&
29352805
(DemangledName == "smoothstep" || DemangledName == "step"))
29362806
return expandOCLBuiltinWithScalarArg(CI, DemangledName);
@@ -3161,30 +3031,6 @@ SPIRVToLLVM::SPIRVToLLVM(Module *LLVMModule, SPIRVModule *TheSPIRVModule)
31613031
DbgTran.reset(new SPIRVToLLVMDbgTran(TheSPIRVModule, LLVMModule, this));
31623032
}
31633033

3164-
std::string SPIRVToLLVM::getOCLBuiltinName(SPIRVInstruction *BI) {
3165-
auto OC = BI->getOpCode();
3166-
if (OC == OpBuildNDRange) {
3167-
auto NDRangeInst = static_cast<SPIRVBuildNDRange *>(BI);
3168-
auto EleTy = ((NDRangeInst->getOperands())[0])->getType();
3169-
int Dim = EleTy->isTypeArray() ? EleTy->getArrayLength() : 1;
3170-
// cygwin does not have std::to_string
3171-
ostringstream OS;
3172-
OS << Dim;
3173-
assert((EleTy->isTypeInt() && Dim == 1) ||
3174-
(EleTy->isTypeArray() && Dim >= 2 && Dim <= 3));
3175-
return std::string(kOCLBuiltinName::NDRangePrefix) + OS.str() + "D";
3176-
}
3177-
3178-
return OCLSPIRVBuiltinMap::rmap(OC);
3179-
}
3180-
3181-
Instruction *SPIRVToLLVM::transOCLBuiltinFromInst(SPIRVInstruction *BI,
3182-
BasicBlock *BB) {
3183-
assert(BB && "Invalid BB");
3184-
auto FuncName = getOCLBuiltinName(BI);
3185-
return transBuiltinFromInst(FuncName, BI, BB);
3186-
}
3187-
31883034
std::string getSPIRVFuncSuffix(SPIRVInstruction *BI) {
31893035
string Suffix = "";
31903036
if (BI->getOpCode() == OpCreatePipeFromPipeStorage) {
@@ -3231,6 +3077,17 @@ std::string getSPIRVFuncSuffix(SPIRVInstruction *BI) {
32313077
llvm_unreachable("Invalid address space");
32323078
}
32333079
}
3080+
if (BI->getOpCode() == OpBuildNDRange) {
3081+
Suffix += kSPIRVPostfix::Divider;
3082+
auto *NDRangeInst = static_cast<SPIRVBuildNDRange *>(BI);
3083+
auto *EleTy = ((NDRangeInst->getOperands())[0])->getType();
3084+
int Dim = EleTy->isTypeArray() ? EleTy->getArrayLength() : 1;
3085+
assert((EleTy->isTypeInt() && Dim == 1) ||
3086+
(EleTy->isTypeArray() && Dim >= 2 && Dim <= 3));
3087+
ostringstream OS;
3088+
OS << Dim;
3089+
Suffix += OS.str() + "D";
3090+
}
32343091
return Suffix;
32353092
}
32363093

@@ -3319,8 +3176,13 @@ bool SPIRVToLLVM::translate() {
33193176
// as calls.
33203177
if (!lowerBuiltinVariablesToCalls(M))
33213178
return false;
3322-
if (!postProcessOCL())
3323-
return false;
3179+
if (BM->getDesiredBIsRepresentation() == BIsRepresentation::SPIRVFriendlyIR) {
3180+
SPIRVWord SrcLangVer = 0;
3181+
BM->getSourceLanguage(&SrcLangVer);
3182+
bool IsCpp = SrcLangVer == kOCLVer::CL21;
3183+
if (!postProcessBuiltinsReturningStruct(M, IsCpp))
3184+
return false;
3185+
}
33243186
eraseUselessFunctions(M);
33253187

33263188
DbgTran->addDbgInfoVersion();

lib/SPIRV/SPIRVReader.h

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ class SPIRVToLLVM {
7878
SPIRVToLLVM(Module *LLVMModule, SPIRVModule *TheSPIRVModule);
7979

8080
static const StringSet<> BuiltInConstFunc;
81-
std::string getOCLBuiltinName(SPIRVInstruction *BI);
8281

8382
Type *transType(SPIRVType *BT, bool IsClassMember = false);
8483
std::string transTypeToOCLTypeName(SPIRVType *BT, bool IsSigned = true);
@@ -123,34 +122,8 @@ class SPIRVToLLVM {
123122
Value *transConvertInst(SPIRVValue *BV, Function *F, BasicBlock *BB);
124123
Instruction *transBuiltinFromInst(const std::string &FuncName,
125124
SPIRVInstruction *BI, BasicBlock *BB);
126-
Instruction *transOCLBuiltinFromInst(SPIRVInstruction *BI, BasicBlock *BB);
127125
Instruction *transSPIRVBuiltinFromInst(SPIRVInstruction *BI, BasicBlock *BB);
128126

129-
/// Post-process translated LLVM module for OpenCL.
130-
bool postProcessOCL();
131-
132-
/// \brief Post-process OpenCL builtin functions returning struct type.
133-
///
134-
/// Some OpenCL builtin functions are translated to SPIR-V instructions with
135-
/// struct type result, e.g. NDRange creation functions. Such functions
136-
/// need to be post-processed to return the struct through sret argument.
137-
bool postProcessOCLBuiltinReturnStruct(Function *F);
138-
139-
/// \brief Post-process OpenCL builtin functions having array argument.
140-
///
141-
/// These functions are translated to functions with array type argument
142-
/// first, then post-processed to have pointer arguments.
143-
bool postProcessOCLBuiltinWithArrayArguments(Function *F,
144-
StringRef DemangledName);
145-
146-
/// \brief Post-process OpBuildNDRange.
147-
/// OpBuildNDRange GlobalWorkSize, LocalWorkSize, GlobalWorkOffset
148-
/// =>
149-
/// call ndrange_XD(GlobalWorkOffset, GlobalWorkSize, LocalWorkSize)
150-
/// \return transformed call instruction.
151-
CallInst *postProcessOCLBuildNDRange(SPIRVInstruction *BI, CallInst *CI,
152-
const std::string &DemangledName);
153-
154127
/// \brief Expand OCL builtin functions with scalar argument, e.g.
155128
/// step, smoothstep.
156129
/// gentype func (fp edge, gentype x)

lib/SPIRV/SPIRVToOCL.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ void SPIRVToOCLBase::visitCallInst(CallInst &CI) {
135135
visitCallSPIRVAvcINTELInstructionBuiltin(&CI, OC);
136136
return;
137137
}
138+
if (OC == OpBuildNDRange) {
139+
visitCallBuildNDRangeBuiltIn(&CI, OC, DemangledName);
140+
return;
141+
}
138142
if (OC == OpGenericCastToPtrExplicit) {
139143
visitCallGenericCastToPtrExplicitBuiltIn(&CI, OC);
140144
return;
@@ -576,6 +580,34 @@ void SPIRVToOCLBase::visitCallSPIRVImageMediaBlockBuiltin(CallInst *CI, Op OC) {
576580
},
577581
&Attrs);
578582
}
583+
void SPIRVToOCLBase::visitCallBuildNDRangeBuiltIn(CallInst *CI, Op OC,
584+
StringRef DemangledName) {
585+
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
586+
mutateCallInstOCL(
587+
M, CI,
588+
[=](CallInst *Call, std::vector<Value *> &Args) {
589+
assert(Args.size() == 3);
590+
// OpenCL built-in has another order of parameters.
591+
auto *GlobalWorkSize = Args[0];
592+
auto *LocalWorkSize = Args[1];
593+
auto *GlobalWorkOffset = Args[2];
594+
Args[0] = GlobalWorkOffset;
595+
Args[1] = GlobalWorkSize;
596+
Args[2] = LocalWorkSize;
597+
// __spirv_BuildNDRange_nD, drop __spirv_
598+
StringRef S = DemangledName;
599+
S = S.drop_front(strlen(kSPIRVName::Prefix));
600+
SmallVector<StringRef, 8> Split;
601+
// BuildNDRange_nD
602+
S.split(Split, kSPIRVPostfix::Divider,
603+
/*MaxSplit=*/-1, /*KeepEmpty=*/false);
604+
assert(Split.size() >= 2 && "Invalid SPIRV function name");
605+
// Cut _nD and add it to function name.
606+
return std::string(kOCLBuiltinName::NDRangePrefix) +
607+
Split[1].substr(0, 3).str();
608+
},
609+
&Attrs);
610+
}
579611

580612
void SPIRVToOCLBase::visitCallGenericCastToPtrExplicitBuiltIn(CallInst *CI,
581613
Op OC) {

lib/SPIRV/SPIRVToOCL.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ class SPIRVToOCLBase : public InstVisitor<SPIRVToOCLBase> {
107107
/// to_{global|local|private} OCL builtin.
108108
void visitCallGenericCastToPtrExplicitBuiltIn(CallInst *CI, Op OC);
109109

110+
/// Transform __spirv_BuildNDRange_{1|2|3}D to
111+
/// ndrange_{1|2|3}D OCL builtin.
112+
void visitCallBuildNDRangeBuiltIn(CallInst *CI, Op OC,
113+
StringRef DemangledName);
114+
110115
/// Transform __spirv_*Convert_R{ReturnType}{_sat}{_rtp|_rtn|_rtz|_rte} to
111116
/// convert_{ReturnType}_{sat}{_rtp|_rtn|_rtz|_rte}
112117
/// example: <2 x i8> __spirv_SatConvertUToS(<2 x i32>) =>

lib/SPIRV/SPIRVToOCL12.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ bool SPIRVToOCL12Base::runSPIRVToOCL(Module &Module) {
5959

6060
visit(*M);
6161

62+
postProcessBuiltinsReturningStruct(M);
63+
postProcessBuiltinsWithArrayArguments(M);
64+
6265
eraseUselessFunctions(&Module);
6366

6467
LLVM_DEBUG(dbgs() << "After SPIRVToOCL12:\n" << *M);

0 commit comments

Comments
 (0)