Reenable "[ESIMD] Remove one of the uses of __SYCL_EXPLICIT_SIMD__ (#3242)" #3311

Merged: 2 commits, Mar 9, 2021
290 changes: 169 additions & 121 deletions llvm/lib/SYCLLowerIR/LowerESIMD.cpp
@@ -79,7 +79,7 @@ namespace {
// /^_Z(\d+)__esimd_\w+/
static constexpr char ESIMD_INTRIN_PREF0[] = "_Z";
static constexpr char ESIMD_INTRIN_PREF1[] = "__esimd_";
static constexpr char SPIRV_INTRIN_PREF[] = "__spirv_";
static constexpr char SPIRV_INTRIN_PREF[] = "__spirv_BuiltIn";

static constexpr char GENX_KERNEL_METADATA[] = "genx.kernels";

@@ -778,108 +778,122 @@ static int getIndexForSuffix(StringRef Suff) {
.Default(-1);
}
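
The body of `getIndexForSuffix` is collapsed in this view; given the `.Default(-1)` tail above and the x/y/z suffixes used below, it is presumably a small `StringSwitch` along these lines (a hedged reconstruction, not the verbatim source):

```cpp
// Hypothetical reconstruction of the collapsed helper: maps a coordinate
// suffix ("x", "y", "z") to its vector lane index, or -1 if unrecognized.
// Uses llvm::StringSwitch from llvm/ADT/StringSwitch.h.
static int getIndexForSuffix(StringRef Suff) {
  return llvm::StringSwitch<int>(Suff)
      .Case("x", 0)
      .Case("y", 1)
      .Case("z", 2)
      .Default(-1);
}
```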

// Helper function to convert SPIRV intrinsic into GenX intrinsic,
// that returns vector of coordinates.
// Example:
// %call = call spir_func i64 @_Z23__spirv_WorkgroupSize_xv()
// =>
// %call.esimd = tail call <3 x i32> @llvm.genx.local.size.v3i32()
// %wgsize.x = extractelement <3 x i32> %call.esimd, i32 0
// %wgsize.x.cast.ty = zext i32 %wgsize.x to i64
static Instruction *generateVectorGenXForSpirv(CallInst &CI, StringRef Suff,
// Helper function to convert extractelement instruction associated with the
// load from SPIRV builtin global, into the GenX intrinsic that returns vector
// of coordinates. It also generates required extractelement and cast
// instructions. Example:
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast
// (<3 x i64> addrspace(1)* @__spirv_BuiltInLocalInvocationId
// to <3 x i64> addrspace(4)*), align 32
// %1 = extractelement <3 x i64> %0, i64 0
//
// =>
//
// %.esimd = call <3 x i32> @llvm.genx.local.id.v3i32()
// %local_id.x = extractelement <3 x i32> %.esimd, i32 0
// %local_id.x.cast.ty = zext i32 %local_id.x to i64
static Instruction *generateVectorGenXForSpirv(ExtractElementInst *EEI,
StringRef Suff,
const std::string &IntrinName,
StringRef ValueName) {
std::string IntrName =
std::string(GenXIntrinsic::getGenXIntrinsicPrefix()) + IntrinName;
auto ID = GenXIntrinsic::lookupGenXIntrinsicID(IntrName);
LLVMContext &Ctx = CI.getModule()->getContext();
LLVMContext &Ctx = EEI->getModule()->getContext();
Type *I32Ty = Type::getInt32Ty(Ctx);
Function *NewFDecl = GenXIntrinsic::getGenXDeclaration(
CI.getModule(), ID, {FixedVectorType::get(I32Ty, 3)});
EEI->getModule(), ID, {FixedVectorType::get(I32Ty, 3)});
Instruction *IntrI =
IntrinsicInst::Create(NewFDecl, {}, CI.getName() + ".esimd", &CI);
IntrinsicInst::Create(NewFDecl, {}, EEI->getName() + ".esimd", EEI);
int ExtractIndex = getIndexForSuffix(Suff);
assert(ExtractIndex != -1 && "Extract index is invalid.");
Twine ExtractName = ValueName + Suff;

Instruction *ExtrI = ExtractElementInst::Create(
IntrI, ConstantInt::get(I32Ty, ExtractIndex), ExtractName, &CI);
Instruction *CastI = addCastInstIfNeeded(&CI, ExtrI);
IntrI, ConstantInt::get(I32Ty, ExtractIndex), ExtractName, EEI);
Instruction *CastI = addCastInstIfNeeded(EEI, ExtrI);
return CastI;
}

// Helper function to convert SPIRV intrinsic into GenX intrinsic,
// that has exact mapping.
// Example:
// %call = call spir_func i64 @_Z21__spirv_WorkgroupId_xv()
// =>
// %group.id.x = tail call i32 @llvm.genx.group.id.x()
// %group.id.x.cast.ty = zext i32 %group.id.x to i64
static Instruction *generateGenXForSpirv(CallInst &CI, StringRef Suff,
// Helper function to convert an extractelement instruction associated with a
// load from a SPIRV builtin global into the GenX intrinsic. It also generates
// the required cast instructions. Example:
//   %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64>
//     addrspace(1)* @__spirv_BuiltInWorkgroupId to <3 x i64> addrspace(4)*), align 32
//   %1 = extractelement <3 x i64> %0, i64 0
// =>
//   %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64>
//     addrspace(1)* @__spirv_BuiltInWorkgroupId to <3 x i64> addrspace(4)*), align 32
//   %group.id.x = call i32 @llvm.genx.group.id.x()
//   %group.id.x.cast.ty = zext i32 %group.id.x to i64
static Instruction *generateGenXForSpirv(ExtractElementInst *EEI,
StringRef Suff,
const std::string &IntrinName) {
std::string IntrName = std::string(GenXIntrinsic::getGenXIntrinsicPrefix()) +
IntrinName + Suff.str();
auto ID = GenXIntrinsic::lookupGenXIntrinsicID(IntrName);
Function *NewFDecl =
GenXIntrinsic::getGenXDeclaration(CI.getModule(), ID, {});
GenXIntrinsic::getGenXDeclaration(EEI->getModule(), ID, {});

Instruction *IntrI =
IntrinsicInst::Create(NewFDecl, {}, IntrinName + Suff.str(), &CI);
Instruction *CastI = addCastInstIfNeeded(&CI, IntrI);
IntrinsicInst::Create(NewFDecl, {}, IntrinName + Suff.str(), EEI);
Instruction *CastI = addCastInstIfNeeded(EEI, IntrI);
return CastI;
}

// This function translates SPIRV intrinsic into GenX intrinsic.
// TODO: Currently, we do not support mixing SYCL and ESIMD kernels.
// Later for ESIMD and SYCL kernels to coexist, we likely need to
// clone call graph that lead from ESIMD kernel to SPIRV intrinsic and
// translate SPIRV intrinsics to GenX intrinsics only in cloned subgraph.
static void
translateSpirvIntrinsic(CallInst *CI, StringRef SpirvIntrName,
SmallVector<Instruction *, 8> &ESIMDToErases) {
auto translateSpirvIntr = [&SpirvIntrName, &ESIMDToErases,
CI](StringRef SpvIName, auto TranslateFunc) {
if (SpirvIntrName.consume_front(SpvIName)) {
Value *TranslatedV = TranslateFunc(*CI, SpirvIntrName.substr(1, 1));
CI->replaceAllUsesWith(TranslatedV);
ESIMDToErases.push_back(CI);
}
};
// This function translates one occurrence of a SPIRV builtin use into a GenX
// intrinsic.
static Value *translateSpirvGlobalUse(ExtractElementInst *EEI,
StringRef SpirvGlobalName) {
Value *IndexV = EEI->getIndexOperand();
assert(isa<ConstantInt>(IndexV) &&
"Extract element index should be a constant");

translateSpirvIntr("WorkgroupSize", [](CallInst &CI, StringRef Suff) {
return generateVectorGenXForSpirv(CI, Suff, "local.size.v3i32", "wgsize.");
});
translateSpirvIntr("LocalInvocationId", [](CallInst &CI, StringRef Suff) {
return generateVectorGenXForSpirv(CI, Suff, "local.id.v3i32", "local_id.");
});
translateSpirvIntr("WorkgroupId", [](CallInst &CI, StringRef Suff) {
return generateGenXForSpirv(CI, Suff, "group.id.");
});
translateSpirvIntr("GlobalInvocationId", [](CallInst &CI, StringRef Suff) {
// Get the suffix based on the index of extractelement instruction
ConstantInt *IndexC = cast<ConstantInt>(IndexV);
std::string Suff;
if (IndexC->equalsInt(0))
Suff = 'x';
else if (IndexC->equalsInt(1))
Suff = 'y';
else if (IndexC->equalsInt(2))
Suff = 'z';
else
assert(false && "Extract element index should be either 0, 1, or 2");
Review comment (Contributor): Nit: llvm_unreachable seems a better fit for the purpose.
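
For reference, the change the reviewer suggests would swap the `assert(false)` fallback for `llvm_unreachable`, which both documents the invariant and marks the path unreachable in release builds. A minimal sketch of that variant (not part of the committed diff):

```cpp
// Sketch: suffix selection with llvm_unreachable instead of assert(false).
// llvm_unreachable comes from llvm/Support/ErrorHandling.h.
ConstantInt *IndexC = cast<ConstantInt>(IndexV);
std::string Suff;
if (IndexC->equalsInt(0))
  Suff = 'x';
else if (IndexC->equalsInt(1))
  Suff = 'y';
else if (IndexC->equalsInt(2))
  Suff = 'z';
else
  llvm_unreachable("Extract element index should be either 0, 1, or 2");
```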


// Translate SPIRV into GenX intrinsic.
if (SpirvGlobalName == "WorkgroupSize") {
return generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize.");
} else if (SpirvGlobalName == "LocalInvocationId") {
return generateVectorGenXForSpirv(EEI, Suff, "local.id.v3i32", "local_id.");
} else if (SpirvGlobalName == "WorkgroupId") {
return generateGenXForSpirv(EEI, Suff, "group.id.");
} else if (SpirvGlobalName == "GlobalInvocationId") {
// GlobalId = LocalId + WorkGroupSize * GroupId
Instruction *LocalIdI =
generateVectorGenXForSpirv(CI, Suff, "local.id.v3i32", "local_id.");
generateVectorGenXForSpirv(EEI, Suff, "local.id.v3i32", "local_id.");
Instruction *WGSizeI =
generateVectorGenXForSpirv(CI, Suff, "local.size.v3i32", "wgsize.");
Instruction *GroupIdI = generateGenXForSpirv(CI, Suff, "group.id.");
generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize.");
Instruction *GroupIdI = generateGenXForSpirv(EEI, Suff, "group.id.");
Instruction *MulI =
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", &CI);
return BinaryOperator::CreateAdd(LocalIdI, MulI, "add", &CI);
});
translateSpirvIntr("GlobalSize", [](CallInst &CI, StringRef Suff) {
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
return BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
} else if (SpirvGlobalName == "GlobalSize") {
// GlobalSize = WorkGroupSize * NumWorkGroups
Instruction *WGSizeI =
generateVectorGenXForSpirv(CI, Suff, "local.size.v3i32", "wgsize.");
generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize.");
Instruction *NumWGI = generateVectorGenXForSpirv(
CI, Suff, "group.count.v3i32", "group_count.");
return BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", &CI);
});
// TODO: Support GlobalOffset SPIRV intrinsics
translateSpirvIntr("GlobalOffset", [](CallInst &CI, StringRef Suff) {
return llvm::Constant::getNullValue(CI.getType());
});
translateSpirvIntr("NumWorkgroups", [](CallInst &CI, StringRef Suff) {
return generateVectorGenXForSpirv(CI, Suff, "group.count.v3i32",
EEI, Suff, "group.count.v3i32", "group_count.");
return BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
} else if (SpirvGlobalName == "GlobalOffset") {
// TODO: Support GlobalOffset SPIRV intrinsics
return llvm::Constant::getNullValue(EEI->getType());
} else if (SpirvGlobalName == "NumWorkgroups") {
return generateVectorGenXForSpirv(EEI, Suff, "group.count.v3i32",
"group_count.");
});
}

return nullptr;
}

static void createESIMDIntrinsicArgs(const ESIMDIntrinDesc &Desc,
@@ -1272,68 +1286,102 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,

auto *CI = dyn_cast<CallInst>(&I);
Function *Callee = nullptr;
if (!CI || !(Callee = CI->getCalledFunction()))
continue;
StringRef Name = Callee->getName();
if (CI && (Callee = CI->getCalledFunction())) {

// See if the Name represents an ESIMD intrinsic and demangle only if it
// does.
if (!Name.consume_front(ESIMD_INTRIN_PREF0))
continue;
// now skip the digits
Name = Name.drop_while([](char C) { return std::isdigit(C); });

// process ESIMD builtins that go through special handling instead of
// the translation procedure
if (Name.startswith("N2cl4sycl5INTEL3gpu8slm_init")) {
// tag the kernel with meta-data SLMSize, and remove this builtin
translateSLMInit(*CI);
ESIMDToErases.push_back(CI);
continue;
}
if (Name.startswith("__esimd_pack_mask")) {
translatePackMask(*CI);
ESIMDToErases.push_back(CI);
continue;
}
if (Name.startswith("__esimd_unpack_mask")) {
translateUnPackMask(*CI);
ESIMDToErases.push_back(CI);
continue;
}
// If vload/vstore is not about the vector-types used by
// those globals marked as genx_volatile, We can translate
// them directly into generic load/store inst. In this way
// those insts can be optimized by llvm ASAP.
if (Name.startswith("__esimd_vload")) {
if (translateVLoad(*CI, GVTS)) {
StringRef Name = Callee->getName();

// See if the Name represents an ESIMD intrinsic and demangle only if it
// does.
if (!Name.consume_front(ESIMD_INTRIN_PREF0))
continue;
// now skip the digits
Name = Name.drop_while([](char C) { return std::isdigit(C); });

// process ESIMD builtins that go through special handling instead of
// the translation procedure
if (Name.startswith("N2cl4sycl5INTEL3gpu8slm_init")) {
// tag the kernel with meta-data SLMSize, and remove this builtin
translateSLMInit(*CI);
ESIMDToErases.push_back(CI);
continue;
}
}
if (Name.startswith("__esimd_vstore")) {
if (translateVStore(*CI, GVTS)) {
if (Name.startswith("__esimd_pack_mask")) {
translatePackMask(*CI);
ESIMDToErases.push_back(CI);
continue;
}
}
if (Name.startswith("__esimd_unpack_mask")) {
translateUnPackMask(*CI);
ESIMDToErases.push_back(CI);
continue;
}
// If vload/vstore is not about the vector-types used by
// those globals marked as genx_volatile, We can translate
// them directly into generic load/store inst. In this way
// those insts can be optimized by llvm ASAP.
if (Name.startswith("__esimd_vload")) {
if (translateVLoad(*CI, GVTS)) {
ESIMDToErases.push_back(CI);
continue;
}
}
if (Name.startswith("__esimd_vstore")) {
if (translateVStore(*CI, GVTS)) {
ESIMDToErases.push_back(CI);
continue;
}
}

if (Name.startswith("__esimd_get_value")) {
translateGetValue(*CI);
ESIMDToErases.push_back(CI);
continue;
}
if (Name.startswith("__esimd_get_value")) {
translateGetValue(*CI);
ESIMDToErases.push_back(CI);
continue;
}

if (Name.consume_front(SPIRV_INTRIN_PREF)) {
translateSpirvIntrinsic(CI, Name, ESIMDToErases);
// For now: if no match, just let it go untranslated.
continue;
if (Name.empty() || !Name.startswith(ESIMD_INTRIN_PREF1))
continue;
// this is ESIMD intrinsic - record for later translation
ESIMDIntrCalls.push_back(CI);
}

if (Name.empty() || !Name.startswith(ESIMD_INTRIN_PREF1))
continue;
// this is ESIMD intrinsic - record for later translation
ESIMDIntrCalls.push_back(CI);
// Translate loads from SPIRV builtin globals into GenX intrinsics
auto *LI = dyn_cast<LoadInst>(&I);
Review comment (Contributor): I think the following would be more efficient from a compile-time and code-simplicity perspective, and slightly more reliable:

  • iterating through M.global_begin()/M.global_end() and finding SPIRV globals among them
  • iterating through each global's uses and applying the ESIMD translation transform (see the sketch after this thread).

This can be considered a nit, but please add a TODO if you decide not to implement it.

Reply (Contributor Author): That's a good point. I agree to reimplement it, but I would prefer to do it in a separate patch if you don't mind. We talked about splitting the LowerESIMD pass into two: a ModulePass and a FunctionPass. I think it would be appropriate to do those two changes together.

Reply (Contributor): OK, sounds good.
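
The structure the reviewer proposes might look roughly like the following. It assumes the SPIRV_INTRIN_PREF constant and translateSpirvGlobalUse helper introduced by this patch; the function name translateSpirvGlobals and the exact traversal are illustrative, not committed code:

```cpp
// Sketch of the reviewer's suggestion: drive the translation from the
// module's global variables instead of scanning every instruction.
static void translateSpirvGlobals(Module &M,
                                  SmallVectorImpl<Instruction *> &ToErase) {
  for (GlobalVariable &GV : M.globals()) {
    if (!GV.getName().startswith(SPIRV_INTRIN_PREF))
      continue;
    StringRef Name =
        GV.getName().drop_front(StringRef(SPIRV_INTRIN_PREF).size());

    // A builtin global is read either directly or through an addrspacecast
    // constant expression; collect the loads in both cases.
    SmallVector<LoadInst *, 4> Loads;
    for (User *U : GV.users()) {
      if (auto *LI = dyn_cast<LoadInst>(U))
        Loads.push_back(LI);
      else if (auto *CE = dyn_cast<ConstantExpr>(U))
        for (User *CU : CE->users())
          if (auto *LI = dyn_cast<LoadInst>(CU))
            Loads.push_back(LI);
    }

    // Rewrite each extractelement use of every load, then drop the load.
    for (LoadInst *LI : Loads) {
      for (User *LU : LI->users())
        if (auto *EEI = dyn_cast<ExtractElementInst>(LU))
          if (Value *V = translateSpirvGlobalUse(EEI, Name)) {
            EEI->replaceAllUsesWith(V);
            ToErase.push_back(EEI);
          }
      ToErase.push_back(LI);
    }
  }
}
```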

if (LI) {
Value *LoadPtrOp = LI->getPointerOperand();
Value *SpirvGlobal = nullptr;
// Look through casts to find SPIRV builtin globals
auto *CE = dyn_cast<ConstantExpr>(LoadPtrOp);
if (CE) {
assert(CE->isCast() && "ConstExpr should be a cast");
SpirvGlobal = CE->getOperand(0);
} else {
SpirvGlobal = LoadPtrOp;
}

if (!isa<GlobalVariable>(SpirvGlobal) ||
!SpirvGlobal->getName().startswith(SPIRV_INTRIN_PREF))
continue;

auto PrefLen = StringRef(SPIRV_INTRIN_PREF).size();

// Go through all the uses of the load instruction from SPIRV builtin
// globals, which are required to be extractelement instructions.
// Translate each of them.
for (auto *LU : LI->users()) {
auto *EEI = dyn_cast<ExtractElementInst>(LU);
assert(EEI && "User of load from global SPIRV builtin is not an "
"extractelement instruction");
Value *TranslatedVal = translateSpirvGlobalUse(
EEI, SpirvGlobal->getName().drop_front(PrefLen));
assert(TranslatedVal &&
"Load from global SPIRV builtin was not translated");
EEI->replaceAllUsesWith(TranslatedVal);
ESIMDToErases.push_back(EEI);
}
// After all users of load were translated, we get rid of the load
// itself.
ESIMDToErases.push_back(LI);
}
}
// Now demangle and translate found ESIMD intrinsic calls
for (auto *CI : ESIMDIntrCalls) {
11 changes: 0 additions & 11 deletions llvm/test/SYCLLowerIR/esimd_lower_intrins.ll
@@ -172,16 +172,6 @@ define dso_local spir_kernel void @FUNC_30() {
; CHECK-NEXT: ret void
}

define dso_local spir_kernel void @FUNC_31() {
; CHECK: define dso_local spir_kernel void @FUNC_31()
%call = call spir_func i64 @_Z27__spirv_LocalInvocationId_xv()
; CHECK-NEXT: %call.esimd = call <3 x i32> @llvm.genx.local.id.v3i32()
; CHECK-NEXT: %local_id.x = extractelement <3 x i32> %call.esimd, i32 0
; CHECK-NEXT: %local_id.x.cast.ty = zext i32 %local_id.x to i64
ret void
; CHECK-NEXT: ret void
}

define dso_local spir_func <16 x i32> @FUNC_32() {
%a_1 = alloca <16 x i32>
%1 = load <16 x i32>, <16 x i32>* %a_1
@@ -318,7 +308,6 @@ define dso_local spir_func <16 x i32> @FUNC_44() {
ret <16 x i32> %ret_val
}

declare dso_local spir_func i64 @_Z27__spirv_LocalInvocationId_xv()
declare dso_local spir_func <32 x i32> @_Z20__esimd_flat_atomic0ILN2cm3gen14CmAtomicOpTypeE2EjLi32ELNS1_9CacheHintE0ELS3_0EENS1_13__vector_typeIT0_XT1_EE4typeENS4_IyXT1_EE4typeENS4_ItXT1_EE4typeE(<32 x i64> %0, <32 x i16> %1)
declare dso_local spir_func <32 x i32> @_Z20__esimd_flat_atomic1ILN2cm3gen14CmAtomicOpTypeE0EjLi32ELNS1_9CacheHintE0ELS3_0EENS1_13__vector_typeIT0_XT1_EE4typeENS4_IyXT1_EE4typeES7_NS4_ItXT1_EE4typeE(<32 x i64> %0, <32 x i32> %1, <32 x i16> %2)
declare dso_local spir_func <32 x i32> @_Z20__esimd_flat_atomic2ILN2cm3gen14CmAtomicOpTypeE7EjLi32ELNS1_9CacheHintE0ELS3_0EENS1_13__vector_typeIT0_XT1_EE4typeENS4_IyXT1_EE4typeES7_S7_NS4_ItXT1_EE4typeE(<32 x i64> %0, <32 x i32> %1, <32 x i32> %2, <32 x i16> %3)