-
Notifications
You must be signed in to change notification settings - Fork 787
Reenable "[ESIMD] Remove one of the uses on __SYCL_EXPLICIT_SIMD__" (#3242) #3311
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -79,7 +79,7 @@ namespace { | |
// /^_Z(\d+)__esimd_\w+/ | ||
static constexpr char ESIMD_INTRIN_PREF0[] = "_Z"; | ||
static constexpr char ESIMD_INTRIN_PREF1[] = "__esimd_"; | ||
static constexpr char SPIRV_INTRIN_PREF[] = "__spirv_"; | ||
static constexpr char SPIRV_INTRIN_PREF[] = "__spirv_BuiltIn"; | ||
|
||
static constexpr char GENX_KERNEL_METADATA[] = "genx.kernels"; | ||
|
||
|
@@ -778,108 +778,122 @@ static int getIndexForSuffix(StringRef Suff) { | |
.Default(-1); | ||
} | ||
|
||
// Helper function to convert SPIRV intrinsic into GenX intrinsic, | ||
// that returns vector of coordinates. | ||
// Example: | ||
// %call = call spir_func i64 @_Z23__spirv_WorkgroupSize_xv() | ||
// => | ||
// %call.esimd = tail call <3 x i32> @llvm.genx.local.size.v3i32() | ||
// %wgsize.x = extractelement <3 x i32> %call.esimd, i32 0 | ||
// %wgsize.x.cast.ty = zext i32 %wgsize.x to i64 | ||
static Instruction *generateVectorGenXForSpirv(CallInst &CI, StringRef Suff, | ||
// Helper function to convert extractelement instruction associated with the | ||
// load from SPIRV builtin global, into the GenX intrinsic that returns vector | ||
// of coordinates. It also generates required extractelement and cast | ||
// instructions. Example: | ||
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast | ||
// (<3 x i64> addrspace(1)* @__spirv_BuiltInLocalInvocationId | ||
// to <3 x i64> addrspace(4)*), align 32 | ||
// %1 = extractelement <3 x i64> %0, i64 0 | ||
// | ||
// => | ||
// | ||
// %.esimd = call <3 x i32> @llvm.genx.local.id.v3i32() | ||
// %local_id.x = extractelement <3 x i32> %.esimd, i32 0 | ||
// %local_id.x.cast.ty = zext i32 %local_id.x to i64 | ||
static Instruction *generateVectorGenXForSpirv(ExtractElementInst *EEI, | ||
StringRef Suff, | ||
const std::string &IntrinName, | ||
StringRef ValueName) { | ||
std::string IntrName = | ||
std::string(GenXIntrinsic::getGenXIntrinsicPrefix()) + IntrinName; | ||
auto ID = GenXIntrinsic::lookupGenXIntrinsicID(IntrName); | ||
LLVMContext &Ctx = CI.getModule()->getContext(); | ||
LLVMContext &Ctx = EEI->getModule()->getContext(); | ||
Type *I32Ty = Type::getInt32Ty(Ctx); | ||
Function *NewFDecl = GenXIntrinsic::getGenXDeclaration( | ||
CI.getModule(), ID, {FixedVectorType::get(I32Ty, 3)}); | ||
EEI->getModule(), ID, {FixedVectorType::get(I32Ty, 3)}); | ||
Instruction *IntrI = | ||
IntrinsicInst::Create(NewFDecl, {}, CI.getName() + ".esimd", &CI); | ||
IntrinsicInst::Create(NewFDecl, {}, EEI->getName() + ".esimd", EEI); | ||
int ExtractIndex = getIndexForSuffix(Suff); | ||
assert(ExtractIndex != -1 && "Extract index is invalid."); | ||
Twine ExtractName = ValueName + Suff; | ||
|
||
Instruction *ExtrI = ExtractElementInst::Create( | ||
IntrI, ConstantInt::get(I32Ty, ExtractIndex), ExtractName, &CI); | ||
Instruction *CastI = addCastInstIfNeeded(&CI, ExtrI); | ||
IntrI, ConstantInt::get(I32Ty, ExtractIndex), ExtractName, EEI); | ||
Instruction *CastI = addCastInstIfNeeded(EEI, ExtrI); | ||
return CastI; | ||
} | ||
|
||
// Helper function to convert SPIRV intrinsic into GenX intrinsic, | ||
// that has exact mapping. | ||
// Example: | ||
// %call = call spir_func i64 @_Z21__spirv_WorkgroupId_xv() | ||
// => | ||
// %group.id.x = tail call i32 @llvm.genx.group.id.x() | ||
// %group.id.x.cast.ty = zext i32 %group.id.x to i64 | ||
static Instruction *generateGenXForSpirv(CallInst &CI, StringRef Suff, | ||
// Helper function to convert extractelement instruction associated with the | ||
// load from SPIRV builtin global, into the GenX intrinsic. It also generates | ||
// required cast instructions. Example: | ||
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64> | ||
// addrspace(1)* @__spirv_BuiltInWorkgroupId to <3 x i64> addrspace(4)*), align | ||
// 32 %1 = extractelement <3 x i64> %0, i64 0 | ||
// => | ||
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64> | ||
// addrspace(1)* @__spirv_BuiltInWorkgroupId to <3 x i64> addrspace(4)*), align | ||
// 32 %group.id.x = call i32 @llvm.genx.group.id.x() %group.id.x.cast.ty = zext | ||
// i32 %group.id.x to i64 | ||
static Instruction *generateGenXForSpirv(ExtractElementInst *EEI, | ||
StringRef Suff, | ||
const std::string &IntrinName) { | ||
std::string IntrName = std::string(GenXIntrinsic::getGenXIntrinsicPrefix()) + | ||
IntrinName + Suff.str(); | ||
auto ID = GenXIntrinsic::lookupGenXIntrinsicID(IntrName); | ||
Function *NewFDecl = | ||
GenXIntrinsic::getGenXDeclaration(CI.getModule(), ID, {}); | ||
GenXIntrinsic::getGenXDeclaration(EEI->getModule(), ID, {}); | ||
|
||
Instruction *IntrI = | ||
IntrinsicInst::Create(NewFDecl, {}, IntrinName + Suff.str(), &CI); | ||
Instruction *CastI = addCastInstIfNeeded(&CI, IntrI); | ||
IntrinsicInst::Create(NewFDecl, {}, IntrinName + Suff.str(), EEI); | ||
Instruction *CastI = addCastInstIfNeeded(EEI, IntrI); | ||
return CastI; | ||
} | ||
|
||
// This function translates SPIRV intrinsic into GenX intrinsic. | ||
// TODO: Currently, we do not support mixing SYCL and ESIMD kernels. | ||
// Later for ESIMD and SYCL kernels to coexist, we likely need to | ||
// clone call graph that lead from ESIMD kernel to SPIRV intrinsic and | ||
// translate SPIRV intrinsics to GenX intrinsics only in cloned subgraph. | ||
static void | ||
translateSpirvIntrinsic(CallInst *CI, StringRef SpirvIntrName, | ||
SmallVector<Instruction *, 8> &ESIMDToErases) { | ||
auto translateSpirvIntr = [&SpirvIntrName, &ESIMDToErases, | ||
CI](StringRef SpvIName, auto TranslateFunc) { | ||
if (SpirvIntrName.consume_front(SpvIName)) { | ||
Value *TranslatedV = TranslateFunc(*CI, SpirvIntrName.substr(1, 1)); | ||
CI->replaceAllUsesWith(TranslatedV); | ||
ESIMDToErases.push_back(CI); | ||
} | ||
}; | ||
// This function translates one occurrence of SPIRV builtin use into GenX | ||
// intrinsic. | ||
static Value *translateSpirvGlobalUse(ExtractElementInst *EEI, | ||
StringRef SpirvGlobalName) { | ||
Value *IndexV = EEI->getIndexOperand(); | ||
assert(isa<ConstantInt>(IndexV) && | ||
"Extract element index should be a constant"); | ||
|
||
translateSpirvIntr("WorkgroupSize", [](CallInst &CI, StringRef Suff) { | ||
return generateVectorGenXForSpirv(CI, Suff, "local.size.v3i32", "wgsize."); | ||
}); | ||
translateSpirvIntr("LocalInvocationId", [](CallInst &CI, StringRef Suff) { | ||
return generateVectorGenXForSpirv(CI, Suff, "local.id.v3i32", "local_id."); | ||
}); | ||
translateSpirvIntr("WorkgroupId", [](CallInst &CI, StringRef Suff) { | ||
return generateGenXForSpirv(CI, Suff, "group.id."); | ||
}); | ||
translateSpirvIntr("GlobalInvocationId", [](CallInst &CI, StringRef Suff) { | ||
// Get the suffix based on the index of extractelement instruction | ||
ConstantInt *IndexC = cast<ConstantInt>(IndexV); | ||
std::string Suff; | ||
if (IndexC->equalsInt(0)) | ||
Suff = 'x'; | ||
else if (IndexC->equalsInt(1)) | ||
Suff = 'y'; | ||
else if (IndexC->equalsInt(2)) | ||
Suff = 'z'; | ||
else | ||
assert(false && "Extract element index should be either 0, 1, or 2"); | ||
|
||
// Translate SPIRV into GenX intrinsic. | ||
if (SpirvGlobalName == "WorkgroupSize") { | ||
return generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize."); | ||
} else if (SpirvGlobalName == "LocalInvocationId") { | ||
return generateVectorGenXForSpirv(EEI, Suff, "local.id.v3i32", "local_id."); | ||
} else if (SpirvGlobalName == "WorkgroupId") { | ||
return generateGenXForSpirv(EEI, Suff, "group.id."); | ||
} else if (SpirvGlobalName == "GlobalInvocationId") { | ||
// GlobalId = LocalId + WorkGroupSize * GroupId | ||
Instruction *LocalIdI = | ||
generateVectorGenXForSpirv(CI, Suff, "local.id.v3i32", "local_id."); | ||
generateVectorGenXForSpirv(EEI, Suff, "local.id.v3i32", "local_id."); | ||
Instruction *WGSizeI = | ||
generateVectorGenXForSpirv(CI, Suff, "local.size.v3i32", "wgsize."); | ||
Instruction *GroupIdI = generateGenXForSpirv(CI, Suff, "group.id."); | ||
generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize."); | ||
Instruction *GroupIdI = generateGenXForSpirv(EEI, Suff, "group.id."); | ||
Instruction *MulI = | ||
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", &CI); | ||
return BinaryOperator::CreateAdd(LocalIdI, MulI, "add", &CI); | ||
}); | ||
translateSpirvIntr("GlobalSize", [](CallInst &CI, StringRef Suff) { | ||
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI); | ||
return BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI); | ||
} else if (SpirvGlobalName == "GlobalSize") { | ||
// GlobalSize = WorkGroupSize * NumWorkGroups | ||
Instruction *WGSizeI = | ||
generateVectorGenXForSpirv(CI, Suff, "local.size.v3i32", "wgsize."); | ||
generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize."); | ||
Instruction *NumWGI = generateVectorGenXForSpirv( | ||
CI, Suff, "group.count.v3i32", "group_count."); | ||
return BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", &CI); | ||
}); | ||
// TODO: Support GlobalOffset SPIRV intrinsics | ||
translateSpirvIntr("GlobalOffset", [](CallInst &CI, StringRef Suff) { | ||
return llvm::Constant::getNullValue(CI.getType()); | ||
}); | ||
translateSpirvIntr("NumWorkgroups", [](CallInst &CI, StringRef Suff) { | ||
return generateVectorGenXForSpirv(CI, Suff, "group.count.v3i32", | ||
EEI, Suff, "group.count.v3i32", "group_count."); | ||
return BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI); | ||
} else if (SpirvGlobalName == "GlobalOffset") { | ||
// TODO: Support GlobalOffset SPIRV intrinsics | ||
return llvm::Constant::getNullValue(EEI->getType()); | ||
} else if (SpirvGlobalName == "NumWorkgroups") { | ||
return generateVectorGenXForSpirv(EEI, Suff, "group.count.v3i32", | ||
"group_count."); | ||
}); | ||
} | ||
|
||
return nullptr; | ||
} | ||
|
||
static void createESIMDIntrinsicArgs(const ESIMDIntrinDesc &Desc, | ||
|
@@ -1272,68 +1286,102 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F, | |
|
||
auto *CI = dyn_cast<CallInst>(&I); | ||
Function *Callee = nullptr; | ||
if (!CI || !(Callee = CI->getCalledFunction())) | ||
continue; | ||
StringRef Name = Callee->getName(); | ||
if (CI && (Callee = CI->getCalledFunction())) { | ||
|
||
// See if the Name represents an ESIMD intrinsic and demangle only if it | ||
// does. | ||
if (!Name.consume_front(ESIMD_INTRIN_PREF0)) | ||
continue; | ||
// now skip the digits | ||
Name = Name.drop_while([](char C) { return std::isdigit(C); }); | ||
|
||
// process ESIMD builtins that go through special handling instead of | ||
// the translation procedure | ||
if (Name.startswith("N2cl4sycl5INTEL3gpu8slm_init")) { | ||
// tag the kernel with meta-data SLMSize, and remove this builtin | ||
translateSLMInit(*CI); | ||
ESIMDToErases.push_back(CI); | ||
continue; | ||
} | ||
if (Name.startswith("__esimd_pack_mask")) { | ||
translatePackMask(*CI); | ||
ESIMDToErases.push_back(CI); | ||
continue; | ||
} | ||
if (Name.startswith("__esimd_unpack_mask")) { | ||
translateUnPackMask(*CI); | ||
ESIMDToErases.push_back(CI); | ||
continue; | ||
} | ||
// If vload/vstore is not about the vector-types used by | ||
// those globals marked as genx_volatile, We can translate | ||
// them directly into generic load/store inst. In this way | ||
// those insts can be optimized by llvm ASAP. | ||
if (Name.startswith("__esimd_vload")) { | ||
if (translateVLoad(*CI, GVTS)) { | ||
StringRef Name = Callee->getName(); | ||
|
||
// See if the Name represents an ESIMD intrinsic and demangle only if it | ||
// does. | ||
if (!Name.consume_front(ESIMD_INTRIN_PREF0)) | ||
continue; | ||
// now skip the digits | ||
Name = Name.drop_while([](char C) { return std::isdigit(C); }); | ||
|
||
// process ESIMD builtins that go through special handling instead of | ||
// the translation procedure | ||
if (Name.startswith("N2cl4sycl5INTEL3gpu8slm_init")) { | ||
// tag the kernel with meta-data SLMSize, and remove this builtin | ||
translateSLMInit(*CI); | ||
ESIMDToErases.push_back(CI); | ||
continue; | ||
} | ||
} | ||
if (Name.startswith("__esimd_vstore")) { | ||
if (translateVStore(*CI, GVTS)) { | ||
if (Name.startswith("__esimd_pack_mask")) { | ||
translatePackMask(*CI); | ||
ESIMDToErases.push_back(CI); | ||
continue; | ||
} | ||
} | ||
if (Name.startswith("__esimd_unpack_mask")) { | ||
translateUnPackMask(*CI); | ||
ESIMDToErases.push_back(CI); | ||
continue; | ||
} | ||
// If vload/vstore is not about the vector-types used by | ||
// those globals marked as genx_volatile, We can translate | ||
// them directly into generic load/store inst. In this way | ||
// those insts can be optimized by llvm ASAP. | ||
if (Name.startswith("__esimd_vload")) { | ||
if (translateVLoad(*CI, GVTS)) { | ||
ESIMDToErases.push_back(CI); | ||
continue; | ||
} | ||
} | ||
if (Name.startswith("__esimd_vstore")) { | ||
if (translateVStore(*CI, GVTS)) { | ||
ESIMDToErases.push_back(CI); | ||
continue; | ||
} | ||
} | ||
|
||
if (Name.startswith("__esimd_get_value")) { | ||
translateGetValue(*CI); | ||
ESIMDToErases.push_back(CI); | ||
continue; | ||
} | ||
if (Name.startswith("__esimd_get_value")) { | ||
translateGetValue(*CI); | ||
ESIMDToErases.push_back(CI); | ||
continue; | ||
} | ||
|
||
if (Name.consume_front(SPIRV_INTRIN_PREF)) { | ||
translateSpirvIntrinsic(CI, Name, ESIMDToErases); | ||
// For now: if no match, just let it go untranslated. | ||
continue; | ||
if (Name.empty() || !Name.startswith(ESIMD_INTRIN_PREF1)) | ||
continue; | ||
// this is ESIMD intrinsic - record for later translation | ||
ESIMDIntrCalls.push_back(CI); | ||
} | ||
|
||
if (Name.empty() || !Name.startswith(ESIMD_INTRIN_PREF1)) | ||
continue; | ||
// this is ESIMD intrinsic - record for later translation | ||
ESIMDIntrCalls.push_back(CI); | ||
// Translate loads from SPIRV builtin globals into GenX intrinsics | ||
auto *LI = dyn_cast<LoadInst>(&I); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the following would be more efficient from compile-time and code simplicity perspective and slightly more reliable:
This can be considered a Nit. But please add a TODO if you decide not to implement. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a good point. I agree to reimplement it, but I would prefer to do it in a separate patch if you don't mind. We talked about splitting the LowerESIMD pass into two: ModulePass and FunctionPass. I think it would be appropriate to do those two changes together. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, sounds good |
||
if (LI) { | ||
Value *LoadPtrOp = LI->getPointerOperand(); | ||
Value *SpirvGlobal = nullptr; | ||
// Look through casts to find SPIRV builtin globals | ||
auto *CE = dyn_cast<ConstantExpr>(LoadPtrOp); | ||
if (CE) { | ||
assert(CE->isCast() && "ConstExpr should be a cast"); | ||
SpirvGlobal = CE->getOperand(0); | ||
} else { | ||
SpirvGlobal = LoadPtrOp; | ||
} | ||
|
||
if (!isa<GlobalVariable>(SpirvGlobal) || | ||
!SpirvGlobal->getName().startswith(SPIRV_INTRIN_PREF)) | ||
continue; | ||
|
||
auto PrefLen = StringRef(SPIRV_INTRIN_PREF).size(); | ||
|
||
// Go through all the uses of the load instruction from SPIRV builtin | ||
// globals, which are required to be extractelement instructions. | ||
// Translate each of them. | ||
for (auto *LU : LI->users()) { | ||
auto *EEI = dyn_cast<ExtractElementInst>(LU); | ||
assert(EEI && "User of load from global SPIRV builtin is not an " | ||
"extractelement instruction"); | ||
Value *TranslatedVal = translateSpirvGlobalUse( | ||
EEI, SpirvGlobal->getName().drop_front(PrefLen)); | ||
assert(TranslatedVal && | ||
"Load from global SPIRV builtin was not translated"); | ||
EEI->replaceAllUsesWith(TranslatedVal); | ||
ESIMDToErases.push_back(EEI); | ||
} | ||
// After all users of load were translated, we get rid of the load | ||
// itself. | ||
ESIMDToErases.push_back(LI); | ||
} | ||
} | ||
// Now demangle and translate found ESIMD intrinsic calls | ||
for (auto *CI : ESIMDIntrCalls) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: llvm_unreachable seems a better fit for the purpose