Skip to content

Commit f470ec7

Browse files
authored
[ESIMD] Re-work loads from globals in sycl-post-link (#4718)
* [ESIMD] Re-work loads from globals in sycl-post-link 1) The re-work in the lowering of loads from globals was required because the previous implementation did not allow handling the loads from scalar globals. 2) Added lowering for __spirv_BuiltInSubgroupLocalInvocationId(), which must always return 0 for ESIMD. Signed-off-by: Vyacheslav N Klochkov <vyacheslav.n.klochkov@intel.com>
1 parent 6493e6b commit f470ec7

File tree

2 files changed

+150
-145
lines changed

2 files changed

+150
-145
lines changed

llvm/lib/SYCLLowerIR/LowerESIMD.cpp

Lines changed: 120 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ namespace id = itanium_demangle;
4343

4444
#define SLM_BTI 254
4545

46+
#define MAX_DIMS 3
47+
4648
namespace {
4749
SmallPtrSet<Type *, 4> collectGenXVolatileTypes(Module &);
4850
void generateKernelMetadata(Module &);
@@ -846,145 +848,131 @@ static Instruction *addCastInstIfNeeded(Instruction *OldI, Instruction *NewI) {
846848
auto CastOpcode = CastInst::getCastOpcode(NewI, false, OITy, false);
847849
NewI = CastInst::Create(CastOpcode, NewI, OITy,
848850
NewI->getName() + ".cast.ty", OldI);
851+
NewI->setDebugLoc(OldI->getDebugLoc());
849852
}
850853
return NewI;
851854
}
852855

853-
static int getIndexForSuffix(StringRef Suff) {
854-
return llvm::StringSwitch<int>(Suff)
855-
.Case("x", 0)
856-
.Case("y", 1)
857-
.Case("z", 2)
858-
.Default(-1);
856+
/// Returns the index from the given extract element instruction \p EEI.
857+
/// It is checked here that the index is either 0, 1, or 2.
858+
static uint64_t getIndexFromExtract(ExtractElementInst *EEI) {
859+
Value *IndexV = EEI->getIndexOperand();
860+
uint64_t IndexValue = cast<ConstantInt>(IndexV)->getZExtValue();
861+
assert(IndexValue < MAX_DIMS &&
862+
"Extract element index should be either 0, 1, or 2");
863+
return IndexValue;
859864
}
860865

861-
// Helper function to convert extractelement instruction associated with the
862-
// load from SPIRV builtin global, into the GenX intrinsic that returns vector
863-
// of coordinates. It also generates required extractelement and cast
864-
// instructions. Example:
865-
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast
866-
// (<3 x i64> addrspace(1)* @__spirv_BuiltInLocalInvocationId
867-
// to <3 x i64> addrspace(4)*), align 32
868-
// %1 = extractelement <3 x i64> %0, i64 0
869-
//
870-
// =>
871-
//
872-
// %.esimd = call <3 x i32> @llvm.genx.local.id.v3i32()
873-
// %local_id.x = extractelement <3 x i32> %.esimd, i32 0
874-
// %local_id.x.cast.ty = zext i32 %local_id.x to i64
875-
static Instruction *generateVectorGenXForSpirv(ExtractElementInst *EEI,
876-
StringRef Suff,
877-
const std::string &IntrinName,
878-
StringRef ValueName) {
879-
std::string IntrName =
880-
std::string(GenXIntrinsic::getGenXIntrinsicPrefix()) + IntrinName;
881-
auto ID = GenXIntrinsic::lookupGenXIntrinsicID(IntrName);
882-
LLVMContext &Ctx = EEI->getModule()->getContext();
883-
Type *I32Ty = Type::getInt32Ty(Ctx);
884-
Function *NewFDecl = GenXIntrinsic::getGenXDeclaration(
885-
EEI->getModule(), ID, {FixedVectorType::get(I32Ty, 3)});
886-
Instruction *IntrI =
887-
IntrinsicInst::Create(NewFDecl, {}, EEI->getName() + ".esimd", EEI);
888-
int ExtractIndex = getIndexForSuffix(Suff);
889-
assert(ExtractIndex != -1 && "Extract index is invalid.");
890-
Twine ExtractName = ValueName + Suff;
891-
892-
Instruction *ExtrI = ExtractElementInst::Create(
893-
IntrI, ConstantInt::get(I32Ty, ExtractIndex), ExtractName, EEI);
894-
Instruction *CastI = addCastInstIfNeeded(EEI, ExtrI);
895-
if (EEI->getDebugLoc()) {
896-
IntrI->setDebugLoc(EEI->getDebugLoc());
897-
ExtrI->setDebugLoc(EEI->getDebugLoc());
898-
// It's OK if ExtrI and CastI is the same instruction
899-
CastI->setDebugLoc(EEI->getDebugLoc());
866+
/// Generates the call of GenX intrinsic \p IntrinName and inserts it
867+
/// right before the given extract element instruction \p EEI using the result
868+
/// of vector load. The parameter \p IsVectorCall tells what version of GenX
869+
/// intrinsic (scalar or vector) to use to lower the load from SPIRV global.
870+
static Instruction *generateGenXCall(ExtractElementInst *EEI,
871+
StringRef IntrinName, bool IsVectorCall) {
872+
uint64_t IndexValue = getIndexFromExtract(EEI);
873+
std::string Suffix =
874+
IsVectorCall
875+
? ".v3i32"
876+
: (Twine(".") + Twine(static_cast<char>('x' + IndexValue))).str();
877+
std::string FullIntrinName = (Twine(GenXIntrinsic::getGenXIntrinsicPrefix()) +
878+
Twine(IntrinName) + Suffix)
879+
.str();
880+
auto ID = GenXIntrinsic::lookupGenXIntrinsicID(FullIntrinName);
881+
Type *I32Ty = Type::getInt32Ty(EEI->getModule()->getContext());
882+
Function *NewFDecl =
883+
IsVectorCall
884+
? GenXIntrinsic::getGenXDeclaration(
885+
EEI->getModule(), ID, FixedVectorType::get(I32Ty, MAX_DIMS))
886+
: GenXIntrinsic::getGenXDeclaration(EEI->getModule(), ID);
887+
888+
std::string ResultName =
889+
(Twine(EEI->getNameOrAsOperand()) + "." + FullIntrinName).str();
890+
Instruction *Inst = IntrinsicInst::Create(NewFDecl, {}, ResultName, EEI);
891+
Inst->setDebugLoc(EEI->getDebugLoc());
892+
893+
if (IsVectorCall) {
894+
Type *I32Ty = Type::getInt32Ty(EEI->getModule()->getContext());
895+
std::string ExtractName =
896+
(Twine(Inst->getNameOrAsOperand()) + ".ext." + Twine(IndexValue)).str();
897+
Inst = ExtractElementInst::Create(Inst, ConstantInt::get(I32Ty, IndexValue),
898+
ExtractName, EEI);
899+
Inst->setDebugLoc(EEI->getDebugLoc());
900900
}
901-
return CastI;
901+
Inst = addCastInstIfNeeded(EEI, Inst);
902+
return Inst;
902903
}
903904

904-
// Helper function to convert extractelement instruction associated with the
905-
// load from SPIRV builtin global, into the GenX intrinsic. It also generates
906-
// required cast instructions. Example:
907-
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64>
908-
// addrspace(1)* @__spirv_BuiltInWorkgroupId to <3 x i64> addrspace(4)*), align
909-
// 32 %1 = extractelement <3 x i64> %0, i64 0
910-
// =>
911-
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64>
912-
// addrspace(1)* @__spirv_BuiltInWorkgroupId to <3 x i64> addrspace(4)*), align
913-
// 32 %group.id.x = call i32 @llvm.genx.group.id.x() %group.id.x.cast.ty = zext
914-
// i32 %group.id.x to i64
915-
static Instruction *generateGenXForSpirv(ExtractElementInst *EEI,
916-
StringRef Suff,
917-
const std::string &IntrinName) {
918-
std::string IntrName = std::string(GenXIntrinsic::getGenXIntrinsicPrefix()) +
919-
IntrinName + Suff.str();
920-
auto ID = GenXIntrinsic::lookupGenXIntrinsicID(IntrName);
921-
Function *NewFDecl =
922-
GenXIntrinsic::getGenXDeclaration(EEI->getModule(), ID, {});
923-
924-
Instruction *IntrI =
925-
IntrinsicInst::Create(NewFDecl, {}, IntrinName + Suff.str(), EEI);
926-
Instruction *CastI = addCastInstIfNeeded(EEI, IntrI);
927-
if (EEI->getDebugLoc()) {
928-
IntrI->setDebugLoc(EEI->getDebugLoc());
929-
// It's OK if IntrI and CastI is the same instruction
930-
CastI->setDebugLoc(EEI->getDebugLoc());
905+
/// Replaces the load \p LI of SPIRV global with corresponding call(s) of GenX
906+
/// intrinsic(s). The users of \p LI may also be transformed if needed for
907+
/// def/use type correctness.
908+
/// The replaced instructions are stored into the given container
909+
/// \p InstsToErase.
910+
static void
911+
translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName,
912+
SmallVectorImpl<Instruction *> &InstsToErase) {
913+
// TODO: Implement support for the following intrinsics:
914+
// uint32_t __spirv_BuiltIn NumSubgroups;
915+
// uint32_t __spirv_BuiltIn SubgroupId;
916+
917+
// Translate those loads from _scalar_ SPIRV globals that can be replaced with
918+
// a const value here.
919+
// The loads from other scalar SPIRV globals may require insertion of GenX
920+
// calls before each user, which is done in the loop by users of 'LI' below.
921+
Value *NewInst = nullptr;
922+
if (SpirvGlobalName == "SubgroupLocalInvocationId") {
923+
NewInst = llvm::Constant::getNullValue(LI->getType());
924+
} else if (SpirvGlobalName == "SubgroupSize" ||
925+
SpirvGlobalName == "SubgroupMaxSize") {
926+
NewInst = llvm::Constant::getIntegerValue(LI->getType(),
927+
llvm::APInt(32, 1, true));
928+
}
929+
if (NewInst) {
930+
LI->replaceAllUsesWith(NewInst);
931+
InstsToErase.push_back(LI);
932+
return;
931933
}
932-
return CastI;
933-
}
934934

935-
// This function translates one occurence of SPIRV builtin use into GenX
936-
// intrinsic.
937-
static Value *translateSpirvGlobalUse(ExtractElementInst *EEI,
938-
StringRef SpirvGlobalName) {
939-
Value *IndexV = EEI->getIndexOperand();
940-
assert(isa<ConstantInt>(IndexV) &&
941-
"Extract element index should be a constant");
935+
// Only loads from _vector_ SPIRV globals reach here now. Their users are
936+
// expected to be ExtractElementInst only, and they are replaced in this loop.
937+
// When loads from _scalar_ SPIRV globals are handled here as well, the users
938+
// will not be replaced by new instructions, but the GenX call replacing the
939+
// original load 'LI' should be inserted before each user.
940+
for (User *LU : LI->users()) {
941+
ExtractElementInst *EEI = cast<ExtractElementInst>(LU);
942+
NewInst = nullptr;
943+
944+
if (SpirvGlobalName == "WorkgroupSize") {
945+
NewInst = generateGenXCall(EEI, "local.size", true);
946+
} else if (SpirvGlobalName == "LocalInvocationId") {
947+
NewInst = generateGenXCall(EEI, "local.id", true);
948+
} else if (SpirvGlobalName == "WorkgroupId") {
949+
NewInst = generateGenXCall(EEI, "group.id", false);
950+
} else if (SpirvGlobalName == "GlobalInvocationId") {
951+
// GlobalId = LocalId + WorkGroupSize * GroupId
952+
Instruction *LocalIdI = generateGenXCall(EEI, "local.id", true);
953+
Instruction *WGSizeI = generateGenXCall(EEI, "local.size", true);
954+
Instruction *GroupIdI = generateGenXCall(EEI, "group.id", false);
955+
Instruction *MulI =
956+
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
957+
NewInst = BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
958+
} else if (SpirvGlobalName == "GlobalSize") {
959+
// GlobalSize = WorkGroupSize * NumWorkGroups
960+
Instruction *WGSizeI = generateGenXCall(EEI, "local.size", true);
961+
Instruction *NumWGI = generateGenXCall(EEI, "group.count", true);
962+
NewInst = BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
963+
} else if (SpirvGlobalName == "GlobalOffset") {
964+
// TODO: Support GlobalOffset SPIRV intrinsics
965+
// Currently all users of load of GlobalOffset are replaced with 0.
966+
NewInst = llvm::Constant::getNullValue(EEI->getType());
967+
} else if (SpirvGlobalName == "NumWorkgroups") {
968+
NewInst = generateGenXCall(EEI, "group.count", true);
969+
}
942970

943-
// Get the suffix based on the index of extractelement instruction
944-
ConstantInt *IndexC = cast<ConstantInt>(IndexV);
945-
std::string Suff;
946-
if (IndexC->equalsInt(0))
947-
Suff = 'x';
948-
else if (IndexC->equalsInt(1))
949-
Suff = 'y';
950-
else if (IndexC->equalsInt(2))
951-
Suff = 'z';
952-
else
953-
assert(false && "Extract element index should be either 0, 1, or 2");
954-
955-
// Translate SPIRV into GenX intrinsic.
956-
if (SpirvGlobalName == "WorkgroupSize") {
957-
return generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize.");
958-
} else if (SpirvGlobalName == "LocalInvocationId") {
959-
return generateVectorGenXForSpirv(EEI, Suff, "local.id.v3i32", "local_id.");
960-
} else if (SpirvGlobalName == "WorkgroupId") {
961-
return generateGenXForSpirv(EEI, Suff, "group.id.");
962-
} else if (SpirvGlobalName == "GlobalInvocationId") {
963-
// GlobalId = LocalId + WorkGroupSize * GroupId
964-
Instruction *LocalIdI =
965-
generateVectorGenXForSpirv(EEI, Suff, "local.id.v3i32", "local_id.");
966-
Instruction *WGSizeI =
967-
generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize.");
968-
Instruction *GroupIdI = generateGenXForSpirv(EEI, Suff, "group.id.");
969-
Instruction *MulI =
970-
BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
971-
return BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
972-
} else if (SpirvGlobalName == "GlobalSize") {
973-
// GlobalSize = WorkGroupSize * NumWorkGroups
974-
Instruction *WGSizeI =
975-
generateVectorGenXForSpirv(EEI, Suff, "local.size.v3i32", "wgsize.");
976-
Instruction *NumWGI = generateVectorGenXForSpirv(
977-
EEI, Suff, "group.count.v3i32", "group_count.");
978-
return BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
979-
} else if (SpirvGlobalName == "GlobalOffset") {
980-
// TODO: Support GlobalOffset SPIRV intrinsics
981-
return llvm::Constant::getNullValue(EEI->getType());
982-
} else if (SpirvGlobalName == "NumWorkgroups") {
983-
return generateVectorGenXForSpirv(EEI, Suff, "group.count.v3i32",
984-
"group_count.");
971+
assert(NewInst && "Load from global SPIRV builtin was not translated");
972+
EEI->replaceAllUsesWith(NewInst);
973+
InstsToErase.push_back(EEI);
985974
}
986-
987-
return nullptr;
975+
InstsToErase.push_back(LI);
988976
}
989977

990978
static void createESIMDIntrinsicArgs(const ESIMDIntrinDesc &Desc,
@@ -1370,8 +1358,7 @@ SmallPtrSet<Type *, 4> collectGenXVolatileTypes(Module &M) {
13701358

13711359
} // namespace
13721360

1373-
PreservedAnalyses SYCLLowerESIMDPass::run(Module &M,
1374-
ModuleAnalysisManager &) {
1361+
PreservedAnalyses SYCLLowerESIMDPass::run(Module &M, ModuleAnalysisManager &) {
13751362
generateKernelMetadata(M);
13761363
SmallPtrSet<Type *, 4> GVTS = collectGenXVolatileTypes(M);
13771364

@@ -1507,23 +1494,11 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,
15071494

15081495
auto PrefLen = StringRef(SPIRV_INTRIN_PREF).size();
15091496

1510-
// Go through all the uses of the load instruction from SPIRV builtin
1511-
// globals, which are required to be extractelement instructions.
1512-
// Translate each of them.
1513-
for (auto *LU : LI->users()) {
1514-
auto *EEI = dyn_cast<ExtractElementInst>(LU);
1515-
assert(EEI && "User of load from global SPIRV builtin is not an "
1516-
"extractelement instruction");
1517-
Value *TranslatedVal = translateSpirvGlobalUse(
1518-
EEI, SpirvGlobal->getName().drop_front(PrefLen));
1519-
assert(TranslatedVal &&
1520-
"Load from global SPIRV builtin was not translated");
1521-
EEI->replaceAllUsesWith(TranslatedVal);
1522-
ESIMDToErases.push_back(EEI);
1523-
}
1524-
// After all users of load were translated, we get rid of the load
1525-
// itself.
1526-
ESIMDToErases.push_back(LI);
1497+
// Translate all uses of the load instruction from SPIRV builtin global.
1498+
// Replaces the original global load and its uses and stores the old
1499+
// instructions to ESIMDToErases.
1500+
translateSpirvGlobalUses(LI, SpirvGlobal->getName().drop_front(PrefLen),
1501+
ESIMDToErases);
15271502
}
15281503
}
15291504
// Now demangle and translate found ESIMD intrinsic calls

sycl/test/esimd/spirv_intrins_trans.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@ size_t caller() {
1818

1919
size_t DoNotOpt;
2020
cl::sycl::buffer<size_t, 1> buf(&DoNotOpt, 1);
21+
uint32_t DoNotOpt32;
22+
cl::sycl::buffer<uint32_t, 1> buf32(&DoNotOpt32, 1);
2123

2224
size_t DoNotOptXYZ[3];
2325
cl::sycl::buffer<size_t, 1> bufXYZ(&DoNotOptXYZ[0], sycl::range<1>(3));
2426

2527
cl::sycl::queue().submit([&](cl::sycl::handler &cgh) {
2628
auto DoNotOptimize = buf.get_access<cl::sycl::access::mode::write>(cgh);
29+
auto DoNotOptimize32 = buf32.get_access<cl::sycl::access::mode::write>(cgh);
2730

2831
kernel<class kernel_GlobalInvocationId_x>([=]() SYCL_ESIMD_KERNEL {
2932
*DoNotOptimize.get_pointer() = __spirv_GlobalInvocationId_x();
@@ -213,6 +216,33 @@ size_t caller() {
213216
// CHECK: {{.*}} call i32 @llvm.genx.group.id.x()
214217
// CHECK: {{.*}} call i32 @llvm.genx.group.id.y()
215218
// CHECK: {{.*}} call i32 @llvm.genx.group.id.z()
219+
220+
kernel<class kernel_SubgroupLocalInvocationId>([=]() SYCL_ESIMD_KERNEL {
221+
*DoNotOptimize.get_pointer() = __spirv_SubgroupLocalInvocationId();
222+
*DoNotOptimize32.get_pointer() = __spirv_SubgroupLocalInvocationId() + 3;
223+
});
224+
// CHECK-LABEL: @{{.*}}kernel_SubgroupLocalInvocationId
225+
// CHECK: [[ZEXT0:%.*]] = zext i32 0 to i64
226+
// CHECK: store i64 [[ZEXT0]]
227+
// CHECK: add i32 0, 3
228+
229+
kernel<class kernel_SubgroupSize>([=]() SYCL_ESIMD_KERNEL {
230+
*DoNotOptimize.get_pointer() = __spirv_SubgroupSize();
231+
*DoNotOptimize32.get_pointer() = __spirv_SubgroupSize() + 7;
232+
});
233+
// CHECK-LABEL: @{{.*}}kernel_SubgroupSize
234+
// CHECK: [[ZEXT0:%.*]] = zext i32 1 to i64
235+
// CHECK: store i64 [[ZEXT0]]
236+
// CHECK: add i32 1, 7
237+
238+
kernel<class kernel_SubgroupMaxSize>([=]() SYCL_ESIMD_KERNEL {
239+
*DoNotOptimize.get_pointer() = __spirv_SubgroupMaxSize();
240+
*DoNotOptimize32.get_pointer() = __spirv_SubgroupMaxSize() + 9;
241+
});
242+
// CHECK-LABEL: @{{.*}}kernel_SubgroupMaxSize
243+
// CHECK: [[ZEXT0:%.*]] = zext i32 1 to i64
244+
// CHECK: store i64 [[ZEXT0]]
245+
// CHECK: add i32 1, 9
216246
});
217247
return DoNotOpt;
218248
}

0 commit comments

Comments
 (0)