Skip to content

[ESIMD] Re-work loads from globals in sycl-post-link #4718

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Oct 9, 2021
265 changes: 120 additions & 145 deletions llvm/lib/SYCLLowerIR/LowerESIMD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ namespace id = itanium_demangle;

#define SLM_BTI 254

#define MAX_DIMS 3

namespace {
SmallPtrSet<Type *, 4> collectGenXVolatileTypes(Module &);
void generateKernelMetadata(Module &);
Expand Down Expand Up @@ -846,145 +848,131 @@ static Instruction *addCastInstIfNeeded(Instruction *OldI, Instruction *NewI) {
auto CastOpcode = CastInst::getCastOpcode(NewI, false, OITy, false);
NewI = CastInst::Create(CastOpcode, NewI, OITy,
NewI->getName() + ".cast.ty", OldI);
NewI->setDebugLoc(OldI->getDebugLoc());
}
return NewI;
}

static int getIndexForSuffix(StringRef Suff) {
return llvm::StringSwitch<int>(Suff)
.Case("x", 0)
.Case("y", 1)
.Case("z", 2)
.Default(-1);
/// Extracts the constant index operand of the given extract element
/// instruction \p EEI and validates that it addresses one of the three
/// spatial dimensions (0 = x, 1 = y, 2 = z).
static uint64_t getIndexFromExtract(ExtractElementInst *EEI) {
  auto *IndexC = cast<ConstantInt>(EEI->getIndexOperand());
  uint64_t Index = IndexC->getZExtValue();
  assert(Index < MAX_DIMS &&
         "Extract element index should be either 0, 1, or 2");
  return Index;
}

// Helper function to convert extractelement instruction associated with the
// load from SPIRV builtin global, into the GenX intrinsic that returns vector
// of coordinates. It also generates required extractelement and cast
// instructions. Example:
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast
// (<3 x i64> addrspace(1)* @__spirv_BuiltInLocalInvocationId
// to <3 x i64> addrspace(4)*), align 32
// %1 = extractelement <3 x i64> %0, i64 0
//
// =>
//
// %.esimd = call <3 x i32> @llvm.genx.local.id.v3i32()
// %local_id.x = extractelement <3 x i32> %.esimd, i32 0
// %local_id.x.cast.ty = zext i32 %local_id.x to i64
static Instruction *generateVectorGenXForSpirv(ExtractElementInst *EEI,
StringRef Suff,
const std::string &IntrinName,
StringRef ValueName) {
std::string IntrName =
std::string(GenXIntrinsic::getGenXIntrinsicPrefix()) + IntrinName;
auto ID = GenXIntrinsic::lookupGenXIntrinsicID(IntrName);
LLVMContext &Ctx = EEI->getModule()->getContext();
Type *I32Ty = Type::getInt32Ty(Ctx);
Function *NewFDecl = GenXIntrinsic::getGenXDeclaration(
EEI->getModule(), ID, {FixedVectorType::get(I32Ty, 3)});
Instruction *IntrI =
IntrinsicInst::Create(NewFDecl, {}, EEI->getName() + ".esimd", EEI);
int ExtractIndex = getIndexForSuffix(Suff);
assert(ExtractIndex != -1 && "Extract index is invalid.");
Twine ExtractName = ValueName + Suff;

Instruction *ExtrI = ExtractElementInst::Create(
IntrI, ConstantInt::get(I32Ty, ExtractIndex), ExtractName, EEI);
Instruction *CastI = addCastInstIfNeeded(EEI, ExtrI);
if (EEI->getDebugLoc()) {
IntrI->setDebugLoc(EEI->getDebugLoc());
ExtrI->setDebugLoc(EEI->getDebugLoc());
// It's OK if ExtrI and CastI is the same instruction
CastI->setDebugLoc(EEI->getDebugLoc());
/// Generates the call of GenX intrinsic \p IntrinName and inserts it
/// right before the given extract element instruction \p EEI using the result
/// of vector load. The parameter \p IsVectorCall tells what version of GenX
/// intrinsic (scalar or vector) to use to lower the load from SPIRV global.
/// Returns the instruction producing the scalar value for \p EEI (an
/// extractelement from the vector intrinsic result, or the scalar intrinsic
/// call itself), cast to the type of \p EEI if needed.
static Instruction *generateGenXCall(ExtractElementInst *EEI,
                                     StringRef IntrinName, bool IsVectorCall) {
  uint64_t IndexValue = getIndexFromExtract(EEI);
  // Vector intrinsics carry a ".v3i32" suffix; scalar ones are per-dimension
  // and are suffixed ".x"/".y"/".z" derived from the extract index.
  std::string Suffix =
      IsVectorCall
          ? ".v3i32"
          : (Twine(".") + Twine(static_cast<char>('x' + IndexValue))).str();
  std::string FullIntrinName = (Twine(GenXIntrinsic::getGenXIntrinsicPrefix()) +
                                Twine(IntrinName) + Suffix)
                                   .str();
  auto ID = GenXIntrinsic::lookupGenXIntrinsicID(FullIntrinName);
  Type *I32Ty = Type::getInt32Ty(EEI->getModule()->getContext());
  Function *NewFDecl =
      IsVectorCall
          ? GenXIntrinsic::getGenXDeclaration(
                EEI->getModule(), ID, FixedVectorType::get(I32Ty, MAX_DIMS))
          : GenXIntrinsic::getGenXDeclaration(EEI->getModule(), ID);

  std::string ResultName =
      (Twine(EEI->getNameOrAsOperand()) + "." + FullIntrinName).str();
  Instruction *Inst = IntrinsicInst::Create(NewFDecl, {}, ResultName, EEI);
  Inst->setDebugLoc(EEI->getDebugLoc());

  if (IsVectorCall) {
    // The vector intrinsic returns all three dimensions; extract the one
    // requested by the original extractelement instruction.
    std::string ExtractName =
        (Twine(Inst->getNameOrAsOperand()) + ".ext." + Twine(IndexValue)).str();
    Inst = ExtractElementInst::Create(Inst, ConstantInt::get(I32Ty, IndexValue),
                                      ExtractName, EEI);
    Inst->setDebugLoc(EEI->getDebugLoc());
  }
  // GenX intrinsics produce i32; widen/cast to EEI's type if they differ.
  Inst = addCastInstIfNeeded(EEI, Inst);
  return Inst;
}

// Helper function to convert extractelement instruction associated with the
// load from SPIRV builtin global, into the GenX intrinsic. It also generates
// required cast instructions. Example:
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64>
// addrspace(1)* @__spirv_BuiltInWorkgroupId to <3 x i64> addrspace(4)*), align
// 32 %1 = extractelement <3 x i64> %0, i64 0
// =>
// %0 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64>
// addrspace(1)* @__spirv_BuiltInWorkgroupId to <3 x i64> addrspace(4)*), align
// 32 %group.id.x = call i32 @llvm.genx.group.id.x() %group.id.x.cast.ty = zext
// i32 %group.id.x to i64
static Instruction *generateGenXForSpirv(ExtractElementInst *EEI,
StringRef Suff,
const std::string &IntrinName) {
std::string IntrName = std::string(GenXIntrinsic::getGenXIntrinsicPrefix()) +
IntrinName + Suff.str();
auto ID = GenXIntrinsic::lookupGenXIntrinsicID(IntrName);
Function *NewFDecl =
GenXIntrinsic::getGenXDeclaration(EEI->getModule(), ID, {});

Instruction *IntrI =
IntrinsicInst::Create(NewFDecl, {}, IntrinName + Suff.str(), EEI);
Instruction *CastI = addCastInstIfNeeded(EEI, IntrI);
if (EEI->getDebugLoc()) {
IntrI->setDebugLoc(EEI->getDebugLoc());
// It's OK if IntrI and CastI is the same instruction
CastI->setDebugLoc(EEI->getDebugLoc());
/// Replaces the load \p LI of a SPIRV builtin global with corresponding
/// call(s) of GenX intrinsic(s). The users of \p LI may also be transformed
/// if needed for def/use type correctness.
/// The replaced instructions are stored into the given container
/// \p InstsToErase.
static void
translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName,
                         SmallVectorImpl<Instruction *> &InstsToErase) {
  // TODO: Implement support for the following intrinsics:
  // uint32_t __spirv_BuiltIn NumSubgroups;
  // uint32_t __spirv_BuiltIn SubgroupId;

  // Translate those loads from _scalar_ SPIRV globals that can be replaced
  // with a const value here.
  // The loads from other scalar SPIRV globals may require insertion of GenX
  // calls before each user, which is done in the loop by users of 'LI' below.
  Value *NewInst = nullptr;
  if (SpirvGlobalName == "SubgroupLocalInvocationId") {
    // Replaced with constant 0.
    NewInst = llvm::Constant::getNullValue(LI->getType());
  } else if (SpirvGlobalName == "SubgroupSize" ||
             SpirvGlobalName == "SubgroupMaxSize") {
    // Replaced with constant 1.
    NewInst = llvm::Constant::getIntegerValue(LI->getType(),
                                              llvm::APInt(32, 1, true));
  }
  if (NewInst) {
    LI->replaceAllUsesWith(NewInst);
    InstsToErase.push_back(LI);
    return;
  }

  // Only loads from _vector_ SPIRV globals reach here now. Their users are
  // expected to be ExtractElementInst only, and they are replaced in this
  // loop. When loads from _scalar_ SPIRV globals are handled here as well,
  // the users will not be replaced by new instructions, but the GenX call
  // replacing the original load 'LI' should be inserted before each user.
  for (User *LU : LI->users()) {
    ExtractElementInst *EEI = cast<ExtractElementInst>(LU);
    NewInst = nullptr;

    if (SpirvGlobalName == "WorkgroupSize") {
      NewInst = generateGenXCall(EEI, "local.size", true);
    } else if (SpirvGlobalName == "LocalInvocationId") {
      NewInst = generateGenXCall(EEI, "local.id", true);
    } else if (SpirvGlobalName == "WorkgroupId") {
      NewInst = generateGenXCall(EEI, "group.id", false);
    } else if (SpirvGlobalName == "GlobalInvocationId") {
      // GlobalId = LocalId + WorkGroupSize * GroupId
      Instruction *LocalIdI = generateGenXCall(EEI, "local.id", true);
      Instruction *WGSizeI = generateGenXCall(EEI, "local.size", true);
      Instruction *GroupIdI = generateGenXCall(EEI, "group.id", false);
      Instruction *MulI =
          BinaryOperator::CreateMul(WGSizeI, GroupIdI, "mul", EEI);
      NewInst = BinaryOperator::CreateAdd(LocalIdI, MulI, "add", EEI);
    } else if (SpirvGlobalName == "GlobalSize") {
      // GlobalSize = WorkGroupSize * NumWorkGroups
      Instruction *WGSizeI = generateGenXCall(EEI, "local.size", true);
      Instruction *NumWGI = generateGenXCall(EEI, "group.count", true);
      NewInst = BinaryOperator::CreateMul(WGSizeI, NumWGI, "mul", EEI);
    } else if (SpirvGlobalName == "GlobalOffset") {
      // TODO: Support GlobalOffset SPIRV intrinsics
      // Currently all users of load of GlobalOffset are replaced with 0.
      NewInst = llvm::Constant::getNullValue(EEI->getType());
    } else if (SpirvGlobalName == "NumWorkgroups") {
      NewInst = generateGenXCall(EEI, "group.count", true);
    }

    assert(NewInst && "Load from global SPIRV builtin was not translated");
    EEI->replaceAllUsesWith(NewInst);
    InstsToErase.push_back(EEI);
  }

  // After all users of the load were translated, get rid of the load itself.
  InstsToErase.push_back(LI);
}

static void createESIMDIntrinsicArgs(const ESIMDIntrinDesc &Desc,
Expand Down Expand Up @@ -1370,8 +1358,7 @@ SmallPtrSet<Type *, 4> collectGenXVolatileTypes(Module &M) {

} // namespace

PreservedAnalyses SYCLLowerESIMDPass::run(Module &M,
ModuleAnalysisManager &) {
PreservedAnalyses SYCLLowerESIMDPass::run(Module &M, ModuleAnalysisManager &) {
generateKernelMetadata(M);
SmallPtrSet<Type *, 4> GVTS = collectGenXVolatileTypes(M);

Expand Down Expand Up @@ -1507,23 +1494,11 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,

auto PrefLen = StringRef(SPIRV_INTRIN_PREF).size();

// Go through all the uses of the load instruction from SPIRV builtin
// globals, which are required to be extractelement instructions.
// Translate each of them.
for (auto *LU : LI->users()) {
auto *EEI = dyn_cast<ExtractElementInst>(LU);
assert(EEI && "User of load from global SPIRV builtin is not an "
"extractelement instruction");
Value *TranslatedVal = translateSpirvGlobalUse(
EEI, SpirvGlobal->getName().drop_front(PrefLen));
assert(TranslatedVal &&
"Load from global SPIRV builtin was not translated");
EEI->replaceAllUsesWith(TranslatedVal);
ESIMDToErases.push_back(EEI);
}
// After all users of load were translated, we get rid of the load
// itself.
ESIMDToErases.push_back(LI);
// Translate all uses of the load instruction from SPIRV builtin global.
// Replaces the original global load and its uses, and stores the old
// instructions to ESIMDToErases.
translateSpirvGlobalUses(LI, SpirvGlobal->getName().drop_front(PrefLen),
ESIMDToErases);
}
}
// Now demangle and translate found ESIMD intrinsic calls
Expand Down
30 changes: 30 additions & 0 deletions sycl/test/esimd/spirv_intrins_trans.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@ size_t caller() {

size_t DoNotOpt;
cl::sycl::buffer<size_t, 1> buf(&DoNotOpt, 1);
uint32_t DoNotOpt32;
cl::sycl::buffer<uint32_t, 1> buf32(&DoNotOpt32, 1);

size_t DoNotOptXYZ[3];
cl::sycl::buffer<size_t, 1> bufXYZ(&DoNotOptXYZ[0], sycl::range<1>(3));

cl::sycl::queue().submit([&](cl::sycl::handler &cgh) {
auto DoNotOptimize = buf.get_access<cl::sycl::access::mode::write>(cgh);
auto DoNotOptimize32 = buf32.get_access<cl::sycl::access::mode::write>(cgh);

kernel<class kernel_GlobalInvocationId_x>([=]() SYCL_ESIMD_KERNEL {
*DoNotOptimize.get_pointer() = __spirv_GlobalInvocationId_x();
Expand Down Expand Up @@ -213,6 +216,33 @@ size_t caller() {
// CHECK: {{.*}} call i32 @llvm.genx.group.id.x()
// CHECK: {{.*}} call i32 @llvm.genx.group.id.y()
// CHECK: {{.*}} call i32 @llvm.genx.group.id.z()

kernel<class kernel_SubgroupLocalInvocationId>([=]() SYCL_ESIMD_KERNEL {
*DoNotOptimize.get_pointer() = __spirv_SubgroupLocalInvocationId();
*DoNotOptimize32.get_pointer() = __spirv_SubgroupLocalInvocationId() + 3;
});
// CHECK-LABEL: @{{.*}}kernel_SubgroupLocalInvocationId
// CHECK: [[ZEXT0:%.*]] = zext i32 0 to i64
// CHECK: store i64 [[ZEXT0]]
// CHECK: add i32 0, 3

kernel<class kernel_SubgroupSize>([=]() SYCL_ESIMD_KERNEL {
*DoNotOptimize.get_pointer() = __spirv_SubgroupSize();
*DoNotOptimize32.get_pointer() = __spirv_SubgroupSize() + 7;
});
// CHECK-LABEL: @{{.*}}kernel_SubgroupSize
// CHECK: [[ZEXT0:%.*]] = zext i32 1 to i64
// CHECK: store i64 [[ZEXT0]]
// CHECK: add i32 1, 7

kernel<class kernel_SubgroupMaxSize>([=]() SYCL_ESIMD_KERNEL {
*DoNotOptimize.get_pointer() = __spirv_SubgroupMaxSize();
*DoNotOptimize32.get_pointer() = __spirv_SubgroupMaxSize() + 9;
});
// CHECK-LABEL: @{{.*}}kernel_SubgroupMaxSize
// CHECK: [[ZEXT0:%.*]] = zext i32 1 to i64
// CHECK: store i64 [[ZEXT0]]
// CHECK: add i32 1, 9
});
return DoNotOpt;
}