Skip to content

Commit bc063ac

Browse files
authored
[ESIMD] Enable support for scalar bfloat16 constructor/conversion used in kernels (#8892)
1 parent 74b3d45 commit bc063ac

File tree

3 files changed

+141
-7
lines changed

3 files changed

+141
-7
lines changed

llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,10 @@ class ESIMDIntrinDescTable {
672672
{"slm_init", {"slm.init", {a(0)}}},
673673
{"bf_cvt", {"bf.cvt", {a(0)}}},
674674
{"tf32_cvt", {"tf32.cvt", {a(0)}}},
675+
{"__devicelib_ConvertFToBF16INTEL",
676+
{"__spirv_ConvertFToBF16INTEL", {a(0)}}},
677+
{"__devicelib_ConvertBF16ToFINTEL",
678+
{"__spirv_ConvertBF16ToFINTEL", {a(0)}}},
675679
{"addc", {"addc", {l(0)}}},
676680
{"subb", {"subb", {l(0)}}},
677681
{"bfn", {"bfn", {a(0), a(1), a(2), t(0)}}}};
@@ -703,6 +707,28 @@ static const ESIMDIntrinDesc &getIntrinDesc(StringRef SrcSpelling) {
703707
return It->second;
704708
}
705709

710+
// Returns true when FunctionName is exactly one of the devicelib bfloat16
// conversion entry points that this pass lowers; any other name yields false.
static bool isDevicelibFunction(StringRef FunctionName) {
  if (FunctionName == "__devicelib_ConvertFToBF16INTEL")
    return true;
  return FunctionName == "__devicelib_ConvertBF16ToFINTEL";
}
716+
717+
// Mangle deviceLib function to make it pass through the regular workflow
718+
// These functions are defined as extern "C" which Demangler that is used
719+
// fails to handle properly.
720+
static std::string mangleDevicelibFunction(StringRef FunctionName) {
721+
if (isDevicelibFunction(FunctionName)) {
722+
if (FunctionName.startswith("__devicelib_ConvertFToBF16INTEL")) {
723+
return (Twine("_Z31") + FunctionName + "RKf").str();
724+
}
725+
if (FunctionName.startswith("__devicelib_ConvertBF16ToFINTEL")) {
726+
return (Twine("_Z31") + FunctionName + "RKt").str();
727+
}
728+
}
729+
return FunctionName.str();
730+
}
731+
706732
Type *parsePrimitiveTypeString(StringRef TyStr, LLVMContext &Ctx) {
707733
return llvm::StringSwitch<Type *>(TyStr)
708734
.Case("bool", IntegerType::getInt1Ty(Ctx))
@@ -1326,6 +1352,46 @@ static void createESIMDIntrinsicArgs(const ESIMDIntrinDesc &Desc,
13261352
}
13271353
}
13281354

1355+
// Create a spirv function declaration
1356+
// This is used for lowering devicelib functions.
1357+
// The function
1358+
// 1. Generates spirv function definition
1359+
// 2. Converts passed by reference argument of devicelib function into passed by
1360+
// value argument of spirv functions
1361+
// 3. Assigns proper attributes to generated function
1362+
static Function *
1363+
createDeviceLibESIMDDeclaration(const ESIMDIntrinDesc &Desc,
1364+
SmallVector<Value *, 16> &GenXArgs,
1365+
CallInst &CI) {
1366+
SmallVector<Type *, 16> ArgTypes;
1367+
IRBuilder<> Bld(&CI);
1368+
for (unsigned i = 0; i < GenXArgs.size(); ++i) {
1369+
Type *NTy = llvm::StringSwitch<Type *>(Desc.GenXSpelling)
1370+
.Case("__spirv_ConvertFToBF16INTEL",
1371+
Type::getFloatTy(CI.getContext()))
1372+
.Case("__spirv_ConvertBF16ToFINTEL",
1373+
Type::getInt16Ty(CI.getContext()))
1374+
.Default(nullptr);
1375+
1376+
auto LI = Bld.CreateLoad(NTy, GenXArgs[i]);
1377+
GenXArgs[i] = LI;
1378+
ArgTypes.push_back(NTy);
1379+
}
1380+
auto *FType = FunctionType::get(CI.getType(), ArgTypes, false);
1381+
Function *F = CI.getModule()->getFunction(Desc.GenXSpelling);
1382+
if (!F) {
1383+
F = Function::Create(FType, GlobalVariable::ExternalLinkage,
1384+
Desc.GenXSpelling, CI.getModule());
1385+
F->addFnAttr(Attribute::NoUnwind);
1386+
F->addFnAttr(Attribute::Convergent);
1387+
F->setDSOLocal(true);
1388+
1389+
F->setCallingConv(CallingConv::SPIR_FUNC);
1390+
}
1391+
1392+
return F;
1393+
}
1394+
13291395
// Create a simple function declaration
13301396
// This is used for testing purposes, when it is impossible to query
13311397
// vc-intrinsics
@@ -1403,7 +1469,9 @@ static void translateESIMDIntrinsicCall(CallInst &CI) {
14031469
using Demangler = id::ManglingParser<SimpleAllocator>;
14041470
Function *F = CI.getCalledFunction();
14051471
llvm::esimd::assert_and_diag(F, "function to translate is invalid");
1406-
StringRef MnglName = F->getName();
1472+
std::string MnglNameStr = mangleDevicelibFunction(F->getName());
1473+
StringRef MnglName = MnglNameStr;
1474+
14071475
Demangler Parser(MnglName.begin(), MnglName.end());
14081476
id::Node *AST = Parser.parse();
14091477

@@ -1416,7 +1484,9 @@ static void translateESIMDIntrinsicCall(CallInst &CI) {
14161484
auto *FE = static_cast<id::FunctionEncoding *>(AST);
14171485
id::StringView BaseNameV = FE->getName()->getBaseName();
14181486

1419-
auto PrefLen = StringRef(ESIMD_INTRIN_PREF1).size();
1487+
auto PrefLen = isDevicelibFunction(F->getName())
1488+
? 0
1489+
: StringRef(ESIMD_INTRIN_PREF1).size();
14201490
StringRef BaseName(BaseNameV.begin() + PrefLen, BaseNameV.size() - PrefLen);
14211491
const auto &Desc = getIntrinDesc(BaseName);
14221492
if (!Desc.isValid()) // TODO remove this once all intrinsics are supported
@@ -1429,7 +1499,9 @@ static void translateESIMDIntrinsicCall(CallInst &CI) {
14291499
Function *NewFDecl = nullptr;
14301500
bool DoesFunctionReturnStructure =
14311501
isStructureReturningFunction(Desc.GenXSpelling);
1432-
if (Desc.GenXSpelling.rfind("test.src.", 0) == 0) {
1502+
if (isDevicelibFunction(F->getName())) {
1503+
NewFDecl = createDeviceLibESIMDDeclaration(Desc, GenXArgs, CI);
1504+
} else if (Desc.GenXSpelling.rfind("test.src.", 0) == 0) {
14331505
// Special case for testing purposes
14341506
NewFDecl = createTestESIMDDeclaration(Desc, GenXArgs, CI);
14351507
} else {
@@ -1724,7 +1796,7 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,
17241796

17251797
// See if the Name represents an ESIMD intrinsic and demangle only if it
17261798
// does.
1727-
if (!Name.consume_front(ESIMD_INTRIN_PREF0))
1799+
if (!Name.consume_front(ESIMD_INTRIN_PREF0) && !isDevicelibFunction(Name))
17281800
continue;
17291801
// now skip the digits
17301802
Name = Name.drop_while([](char C) { return std::isdigit(C); });
@@ -1771,7 +1843,8 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,
17711843
assert(!Name.startswith("__sycl_set_kernel_properties") &&
17721844
"__sycl_set_kernel_properties must have been lowered");
17731845

1774-
if (Name.empty() || !Name.startswith(ESIMD_INTRIN_PREF1))
1846+
if (Name.empty() ||
1847+
(!Name.startswith(ESIMD_INTRIN_PREF1) && !isDevicelibFunction(Name)))
17751848
continue;
17761849
// this is ESIMD intrinsic - record for later translation
17771850
ESIMDIntrCalls.push_back(CI);
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// REQUIRES: gpu
2+
// UNSUPPORTED: gpu-intel-gen9 || cuda || hip
3+
// RUN: %clangxx -fsycl %s -o %t.out
4+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
5+
// XFAIL: gpu && !esimd_emulator
6+
//==- bfloat16Constructor.cpp - Test to verify use of bfloat16 constructor -==//
7+
//
8+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9+
// See https://llvm.org/LICENSE.txt for license information.
10+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
// This is a basic test to verify the use of the bfloat16 constructor in a kernel.
15+
// TODO: Enable the test once the GPU RT supporting the functionality reaches
16+
// the CI
17+
18+
#include <CL/sycl.hpp>
19+
#include <ext/intel/esimd.hpp>
20+
#include <iostream>
21+
22+
using namespace sycl;
23+
24+
int main() {
25+
constexpr unsigned Size = 32;
26+
constexpr unsigned VL = 32;
27+
constexpr unsigned GroupSize = 1;
28+
29+
queue q;
30+
auto dev = q.get_device();
31+
std::cout << "Running on " << dev.get_info<info::device::name>() << "\n";
32+
auto *C = malloc_shared<float>(Size * sizeof(float), dev, q.get_context());
33+
34+
for (auto i = 0; i != Size; i++) {
35+
C[i] = 7;
36+
}
37+
38+
nd_range<1> Range(range<1>(Size / VL), range<1>(GroupSize));
39+
40+
auto e = q.submit([&](handler &cgh) {
41+
cgh.parallel_for<class Test>(Range, [=](nd_item<1> i) SYCL_ESIMD_KERNEL {
42+
using bf16 = sycl::ext::oneapi::bfloat16;
43+
using namespace __ESIMD_NS;
44+
using namespace __ESIMD_ENS;
45+
simd<bf16, 32> data_bf16 = bf16(0);
46+
simd<float, 32> data = data_bf16;
47+
lsc_block_store<float, 32>(C, data);
48+
});
49+
});
50+
e.wait();
51+
bool Pass = true;
52+
for (auto i = 0; i != Size; i++) {
53+
if (C[i] != 0) {
54+
Pass = false;
55+
}
56+
}
57+
58+
free(C, q);
59+
std::cout << (Pass ? "Test Passed\n" : "Test FAILED\n");
60+
return 0;
61+
}

sycl/test/esimd/fp16_converts.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void bf16_scalar() {
4444
// The actual support in GPU RT is on the way though.
4545
float F32_scalar = 1;
4646
bfloat16 BF16_scalar = F32_scalar;
47-
// CHECK: call spir_func zeroext i16 @__devicelib_ConvertFToBF16INTEL(float {{[^)]+}})
47+
// CHECK: call i16 @__spirv_ConvertFToBF16INTEL(float {{[^)]+}})
4848
float F32_scalar_conv = BF16_scalar;
49-
// CHECK: call spir_func float @__devicelib_ConvertBF16ToFINTEL(i16 {{[^)]+}})
49+
// CHECK: call float @__spirv_ConvertBF16ToFINTEL(i16 {{[^)]+}})
5050
}

0 commit comments

Comments
 (0)