Skip to content

Commit 24f81e3

Browse files
[Mutator] Migrate more miscellaneous calls to the mutator interface. (#1619)
1 parent 1aa3b17 commit 24f81e3

File tree

2 files changed

+170
-236
lines changed

2 files changed

+170
-236
lines changed

lib/SPIRV/OCLToSPIRV.cpp

Lines changed: 67 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -423,49 +423,45 @@ void OCLToSPIRVBase::visitCallNDRange(CallInst *CI, StringRef DemangledName) {
423423
StringRef LenStr = DemangledName.substr(8, 1);
424424
auto Len = atoi(LenStr.data());
425425
assert(Len >= 1 && Len <= 3);
426+
// Translate ndrange_ND into differently named SPIR-V
427+
// decorated functions because they have array arugments
428+
// of different dimension which mangled the same way.
429+
std::string Postfix("_");
430+
Postfix += LenStr;
431+
Postfix += 'D';
432+
std::string FuncName = getSPIRVFuncName(OpBuildNDRange, Postfix);
433+
auto Mutator = mutateCallInst(CI, FuncName);
434+
426435
// SPIR-V ndrange structure requires 3 members in the following order:
427436
// global work offset
428437
// global work size
429438
// local work size
430439
// The arguments need to add missing members.
431-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
432-
mutateCallInstSPIRV(
433-
M, CI,
434-
[=](CallInst *, std::vector<Value *> &Args) {
435-
for (size_t I = 1, E = Args.size(); I != E; ++I)
436-
Args[I] = getScalarOrArray(Args[I], Len, CI);
437-
switch (Args.size()) {
438-
case 2: {
439-
// Has global work size.
440-
auto T = Args[1]->getType();
441-
auto C = getScalarOrArrayConstantInt(CI, T, Len, 0);
442-
Args.push_back(C);
443-
Args.push_back(C);
444-
} break;
445-
case 3: {
446-
// Has global and local work size.
447-
auto T = Args[1]->getType();
448-
Args.push_back(getScalarOrArrayConstantInt(CI, T, Len, 0));
449-
} break;
450-
case 4: {
451-
// Move offset arg to the end
452-
auto OffsetPos = Args.begin() + 1;
453-
Value *OffsetVal = *OffsetPos;
454-
Args.erase(OffsetPos);
455-
Args.push_back(OffsetVal);
456-
} break;
457-
default:
458-
assert(0 && "Invalid number of arguments");
459-
}
460-
// Translate ndrange_ND into differently named SPIR-V
461-
// decorated functions because they have array arugments
462-
// of different dimension which mangled the same way.
463-
std::string Postfix("_");
464-
Postfix += LenStr;
465-
Postfix += 'D';
466-
return getSPIRVFuncName(OpBuildNDRange, Postfix);
467-
},
468-
&Attrs);
440+
for (size_t I = 1, E = CI->arg_size(); I != E; ++I)
441+
Mutator.mapArg(I, [=](Value *V) { return getScalarOrArray(V, Len, CI); });
442+
switch (CI->arg_size()) {
443+
case 2: {
444+
// Has global work size.
445+
auto *T = Mutator.getArg(1)->getType();
446+
auto *C = getScalarOrArrayConstantInt(CI, T, Len, 0);
447+
Mutator.appendArg(C);
448+
Mutator.appendArg(C);
449+
break;
450+
}
451+
case 3: {
452+
// Has global and local work size.
453+
auto *T = Mutator.getArg(1)->getType();
454+
Mutator.appendArg(getScalarOrArrayConstantInt(CI, T, Len, 0));
455+
break;
456+
}
457+
case 4: {
458+
// Move offset arg to the end
459+
Mutator.moveArg(1, CI->arg_size() - 1);
460+
break;
461+
}
462+
default:
463+
assert(0 && "Invalid number of arguments");
464+
}
469465
}
470466

471467
void OCLToSPIRVBase::visitCallAsyncWorkGroupCopy(CallInst *CI,
@@ -1299,29 +1295,23 @@ void OCLToSPIRVBase::visitCallDot(CallInst *CI, StringRef MangledName,
12991295
IsSecondSigned = (IsDot) ? (MangledName[MangledName.size() - 2] == 'i')
13001296
: (MangledName[MangledName.size() - 3] == 'i');
13011297
}
1302-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
1303-
mutateCallInstSPIRV(
1304-
M, CI,
1305-
[=](CallInst *, std::vector<Value *> &Args) {
1306-
// If arguments are in order unsigned -> signed
1307-
// then the translator should swap them,
1308-
// so that the OpSUDotKHR can be used properly
1309-
if (IsFirstSigned == false && IsSecondSigned == true) {
1310-
std::swap(Args[0], Args[1]);
1311-
}
1312-
Op OC;
1313-
if (IsDot) {
1314-
OC = (IsFirstSigned != IsSecondSigned
1315-
? OpSUDot
1316-
: ((IsFirstSigned) ? OpSDot : OpUDot));
1317-
} else {
1318-
OC = (IsFirstSigned != IsSecondSigned
1319-
? OpSUDotAccSat
1320-
: ((IsFirstSigned) ? OpSDotAccSat : OpUDotAccSat));
1321-
}
1322-
return getSPIRVFuncName(OC);
1323-
},
1324-
&Attrs);
1298+
Op OC;
1299+
if (IsDot) {
1300+
OC =
1301+
(IsFirstSigned != IsSecondSigned ? OpSUDot
1302+
: ((IsFirstSigned) ? OpSDot : OpUDot));
1303+
} else {
1304+
OC = (IsFirstSigned != IsSecondSigned
1305+
? OpSUDotAccSat
1306+
: ((IsFirstSigned) ? OpSDotAccSat : OpUDotAccSat));
1307+
}
1308+
auto Mutator = mutateCallInst(CI, OC);
1309+
// If arguments are in order unsigned -> signed
1310+
// then the translator should swap them,
1311+
// so that the OpSUDotKHR can be used properly
1312+
if (IsFirstSigned == false && IsSecondSigned == true) {
1313+
Mutator.moveArg(1, 0);
1314+
}
13251315
}
13261316

13271317
void OCLToSPIRVBase::visitCallScalToVec(CallInst *CI, StringRef MangledName,
@@ -1362,31 +1352,22 @@ void OCLToSPIRVBase::visitCallScalToVec(CallInst *CI, StringRef MangledName,
13621352
ScalarPos.push_back(1);
13631353
}
13641354

1365-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
1366-
mutateCallInstSPIRV(
1367-
M, CI,
1368-
[=](CallInst *, std::vector<Value *> &Args) {
1369-
Args.resize(VecPos.size() + ScalarPos.size());
1370-
for (auto I : VecPos) {
1371-
Args[I] = CI->getOperand(I);
1372-
}
1373-
auto VecElemCount =
1374-
cast<VectorType>(CI->getOperand(VecPos[0])->getType())
1375-
->getElementCount();
1376-
for (auto I : ScalarPos) {
1377-
Instruction *Inst = InsertElementInst::Create(
1378-
UndefValue::get(CI->getOperand(VecPos[0])->getType()),
1379-
CI->getOperand(I), getInt32(M, 0), "", CI);
1380-
Value *NewVec = new ShuffleVectorInst(
1381-
Inst, UndefValue::get(CI->getOperand(VecPos[0])->getType()),
1382-
ConstantVector::getSplat(VecElemCount, getInt32(M, 0)), "", CI);
1383-
1384-
Args[I] = NewVec;
1385-
}
1386-
return getSPIRVExtFuncName(SPIRVEIS_OpenCL,
1387-
getExtOp(MangledName, DemangledName));
1388-
},
1389-
&Attrs);
1355+
assert(CI->arg_size() == VecPos.size() + ScalarPos.size() &&
1356+
"Argument counts do not match up.");
1357+
1358+
Type *VecTy = CI->getOperand(VecPos[0])->getType();
1359+
auto VecElemCount = cast<VectorType>(VecTy)->getElementCount();
1360+
auto Mutator = mutateCallInst(
1361+
CI, getSPIRVExtFuncName(SPIRVEIS_OpenCL,
1362+
getExtOp(MangledName, DemangledName)));
1363+
for (auto I : ScalarPos)
1364+
Mutator.mapArg(I, [&](Value *V) {
1365+
Instruction *Inst = InsertElementInst::Create(UndefValue::get(VecTy), V,
1366+
getInt32(M, 0), "", CI);
1367+
return new ShuffleVectorInst(
1368+
Inst, UndefValue::get(VecTy),
1369+
ConstantVector::getSplat(VecElemCount, getInt32(M, 0)), "", CI);
1370+
});
13901371
}
13911372

13921373
void OCLToSPIRVBase::visitCallGetImageChannel(CallInst *CI,

0 commit comments

Comments
 (0)