Skip to content

Commit 997b029

Browse files
[SYCL] Filter out unneeded device images with lower state than requested
When fetching device images compatible with non-input states, we can ignore an image if another one with a higher state is available for all the possible kernel-device pairs. This patch adds the logic for filtering out such unnecessary images so that we can avoid JIT compilation if both AOT and SPIRV images are present.
1 parent 7ee2f26 commit 997b029

File tree

6 files changed

+313
-51
lines changed

6 files changed

+313
-51
lines changed

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 102 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1683,46 +1683,120 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
16831683
}
16841684
assert(BinImages.size() > 0 && "Expected to find at least one device image");
16851685

1686+
// Ignore images with incompatible state. Image is considered compatible
1687+
// with a target state if an image is already in the target state or can
1688+
// be brought to target state by compiling/linking/building.
1689+
//
1690+
// Example: an image in "executable" state is not compatible with
1691+
// "input" target state - there is no operation to convert the image it
1692+
// to "input" state. An image in "input" state is compatible with
1693+
// "executable" target state because it can be built to get into
1694+
// "executable" state.
1695+
for (auto It = BinImages.begin(); It != BinImages.end();) {
1696+
if (getBinImageState(*It) > TargetState)
1697+
It = BinImages.erase(It);
1698+
else
1699+
++It;
1700+
}
1701+
16861702
std::vector<device_image_plain> SYCLDeviceImages;
1687-
for (RTDeviceBinaryImage *BinImage : BinImages) {
1688-
const bundle_state ImgState = getBinImageState(BinImage);
1689-
1690-
// Ignore images with incompatible state. Image is considered compatible
1691-
// with a target state if an image is already in the target state or can
1692-
// be brought to target state by compiling/linking/building.
1693-
//
1694-
// Example: an image in "executable" state is not compatible with
1695-
// "input" target state - there is no operation to convert the image it
1696-
// to "input" state. An image in "input" state is compatible with
1697-
// "executable" target state because it can be built to get into
1698-
// "executable" state.
1699-
if (ImgState > TargetState)
1700-
continue;
17011703

1702-
for (const sycl::device &Dev : Devs) {
1704+
// If a non-input state is requested, we can filter out some compatible
1705+
// images and return only those with the highest compatible state for each
1706+
// device-kernel pair. This map tracks how many kernel-device pairs need each
1707+
// image, so that any unneeded ones are skipped.
1708+
// TODO this has no effect if the requested state is input, consider having
1709+
// a separate branch for that case to avoid unnecessary tracking work.
1710+
struct DeviceBinaryImageInfo {
1711+
std::shared_ptr<std::vector<sycl::kernel_id>> KernelIDs;
1712+
bundle_state State = bundle_state::input;
1713+
int RequirementCounter = 0;
1714+
};
1715+
std::map<RTDeviceBinaryImage *, DeviceBinaryImageInfo> ImageInfoMap;
1716+
1717+
for (const sycl::device &Dev : Devs) {
1718+
// Track the highest image state for each requested kernel.
1719+
using StateImagesPairT =
1720+
std::pair<bundle_state, std::vector<RTDeviceBinaryImage *>>;
1721+
using KernelImageMapT =
1722+
std::map<kernel_id, StateImagesPairT, LessByNameComp>;
1723+
KernelImageMapT KernelImageMap;
1724+
if (!KernelIDs.empty())
1725+
for (const kernel_id &KernelID : KernelIDs)
1726+
KernelImageMap.insert({KernelID, {}});
1727+
1728+
for (RTDeviceBinaryImage *BinImage : BinImages) {
17031729
if (!compatibleWithDevice(BinImage, Dev) ||
17041730
!doesDevSupportDeviceRequirements(Dev, *BinImage))
17051731
continue;
17061732

1707-
std::shared_ptr<std::vector<sycl::kernel_id>> KernelIDs;
1708-
// Collect kernel names for the image
1709-
{
1710-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
1711-
KernelIDs = m_BinImg2KernelIDs[BinImage];
1712-
// If the image does not contain any non-service kernels we can skip it.
1713-
if (!KernelIDs || KernelIDs->empty())
1714-
continue;
1733+
auto InsertRes = ImageInfoMap.insert({BinImage, {}});
1734+
DeviceBinaryImageInfo &ImgInfo = InsertRes.first->second;
1735+
if (InsertRes.second) {
1736+
ImgInfo.State = getBinImageState(BinImage);
1737+
std::shared_ptr<std::vector<sycl::kernel_id>> ImageKernelIDs;
1738+
// Collect kernel names for the image
1739+
{
1740+
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
1741+
ImgInfo.KernelIDs = m_BinImg2KernelIDs[BinImage];
1742+
}
17151743
}
1744+
const bundle_state ImgState = ImgInfo.State;
1745+
const std::shared_ptr<std::vector<sycl::kernel_id>> &ImageKernelIDs =
1746+
ImgInfo.KernelIDs;
1747+
int &ImgRequirementCounter = ImgInfo.RequirementCounter;
17161748

1717-
DeviceImageImplPtr Impl = std::make_shared<detail::device_image_impl>(
1718-
BinImage, Ctx, Devs, ImgState, KernelIDs, /*PIProgram=*/nullptr);
1749+
// If the image does not contain any non-service kernels we can skip it.
1750+
if (!ImageKernelIDs || ImageKernelIDs->empty())
1751+
continue;
17191752

1720-
SYCLDeviceImages.push_back(
1721-
createSyclObjFromImpl<device_image_plain>(Impl));
1722-
break;
1753+
// Update tracked information.
1754+
for (kernel_id &KernelID : *ImageKernelIDs) {
1755+
StateImagesPairT *StateImagesPair;
1756+
// If only specific kernels are requested, ignore the rest.
1757+
if (!KernelIDs.empty()) {
1758+
auto It = KernelImageMap.find(KernelID);
1759+
if (It == KernelImageMap.end())
1760+
continue;
1761+
StateImagesPair = &It->second;
1762+
} else
1763+
StateImagesPair = &KernelImageMap[KernelID];
1764+
1765+
auto &[KernelImagesState, KernelImages] = *StateImagesPair;
1766+
1767+
if (KernelImages.empty()) {
1768+
KernelImagesState = ImgState;
1769+
KernelImages.push_back(BinImage);
1770+
++ImgRequirementCounter;
1771+
} else if (KernelImagesState < ImgState) {
1772+
for (RTDeviceBinaryImage *Img : KernelImages) {
1773+
auto It = ImageInfoMap.find(Img);
1774+
assert(It != ImageInfoMap.end());
1775+
assert(It->second.RequirementCounter > 0);
1776+
--(It->second.RequirementCounter);
1777+
}
1778+
KernelImages.clear();
1779+
KernelImages.push_back(BinImage);
1780+
++ImgRequirementCounter;
1781+
} else if (KernelImagesState == ImgState) {
1782+
KernelImages.push_back(BinImage);
1783+
++ImgRequirementCounter;
1784+
}
1785+
}
17231786
}
17241787
}
17251788

1789+
for (const auto &ImgInfoPair : ImageInfoMap) {
1790+
if (ImgInfoPair.second.RequirementCounter == 0)
1791+
continue;
1792+
1793+
DeviceImageImplPtr Impl = std::make_shared<detail::device_image_impl>(
1794+
ImgInfoPair.first, Ctx, Devs, ImgInfoPair.second.State,
1795+
ImgInfoPair.second.KernelIDs, /*PIProgram=*/nullptr);
1796+
1797+
SYCLDeviceImages.push_back(createSyclObjFromImpl<device_image_plain>(Impl));
1798+
}
1799+
17261800
return SYCLDeviceImages;
17271801
}
17281802

sycl/unittests/SYCL2020/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_sycl_unittest(SYCL2020Tests OBJECT
44
GetNativeOpenCL.cpp
55
SpecializationConstant.cpp
66
KernelBundle.cpp
7+
KernelBundleStateFiltering.cpp
78
KernelID.cpp
89
HasExtension.cpp
910
IsCompatible.cpp
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
//==---- KernelBundleStateFiltering.cpp --- Kernel bundle unit test --------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <detail/device_impl.hpp>
10+
#include <detail/kernel_bundle_impl.hpp>
11+
#include <sycl/sycl.hpp>
12+
13+
#include <helpers/MockKernelInfo.hpp>
14+
#include <helpers/PiImage.hpp>
15+
#include <helpers/PiMock.hpp>
16+
17+
#include <gtest/gtest.h>
18+
19+
#include <algorithm>
20+
#include <set>
21+
#include <vector>
22+
23+
class KernelA;
24+
class KernelB;
25+
class KernelC;
26+
class KernelD;
27+
namespace sycl {
28+
__SYCL_INLINE_VER_NAMESPACE(_V1) {
29+
namespace detail {
30+
template <> struct KernelInfo<KernelA> : public unittest::MockKernelInfoBase {
31+
static constexpr const char *getName() { return "KernelA"; }
32+
};
33+
template <> struct KernelInfo<KernelB> : public unittest::MockKernelInfoBase {
34+
static constexpr const char *getName() { return "KernelB"; }
35+
};
36+
template <> struct KernelInfo<KernelC> : public unittest::MockKernelInfoBase {
37+
static constexpr const char *getName() { return "KernelC"; }
38+
};
39+
template <> struct KernelInfo<KernelD> : public unittest::MockKernelInfoBase {
40+
static constexpr const char *getName() { return "KernelD"; }
41+
};
42+
} // namespace detail
43+
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
44+
} // namespace sycl
45+
46+
namespace {
47+
48+
std::set<const void *> TrackedImages;
49+
sycl::unittest::PiImage
50+
generateDefaultImage(std::initializer_list<std::string> KernelNames,
51+
pi_device_binary_type BinaryType,
52+
const char *DeviceTargetSpec) {
53+
using namespace sycl::unittest;
54+
55+
PiPropertySet PropSet;
56+
57+
static unsigned char NImage = 0;
58+
std::vector<unsigned char> Bin{NImage++};
59+
60+
PiArray<PiOffloadEntry> Entries = makeEmptyKernels(KernelNames);
61+
62+
PiImage Img{BinaryType, // Format
63+
DeviceTargetSpec,
64+
"", // Compile options
65+
"", // Link options
66+
std::move(Bin),
67+
std::move(Entries),
68+
std::move(PropSet)};
69+
const void *BinaryPtr = Img.getBinaryPtr();
70+
TrackedImages.insert(BinaryPtr);
71+
72+
return Img;
73+
}
74+
75+
// Image 0: input, KernelA KernelB
76+
// Image 1: exe, KernelA
77+
// Image 2: input, KernelC
78+
// Image 3: exe, KernelC
79+
// Image 4: input, KernelD
80+
sycl::unittest::PiImage Imgs[] = {
81+
generateDefaultImage({"KernelA", "KernelB"}, PI_DEVICE_BINARY_TYPE_SPIRV,
82+
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64),
83+
generateDefaultImage({"KernelA"}, PI_DEVICE_BINARY_TYPE_NATIVE,
84+
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64_X86_64),
85+
generateDefaultImage({"KernelC"}, PI_DEVICE_BINARY_TYPE_SPIRV,
86+
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64),
87+
generateDefaultImage({"KernelC"}, PI_DEVICE_BINARY_TYPE_NATIVE,
88+
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64_X86_64),
89+
generateDefaultImage({"KernelD"}, PI_DEVICE_BINARY_TYPE_SPIRV,
90+
__SYCL_PI_DEVICE_BINARY_TARGET_SPIRV64)};
91+
92+
sycl::unittest::PiImageArray<std::size(Imgs)> ImgArray{Imgs};
93+
std::vector<unsigned char> UsedImageIndices;
94+
95+
void redefinedPiProgramCreateCommon(const void *bin) {
96+
if (TrackedImages.count(bin) != 0) {
97+
unsigned char ImgIdx = *reinterpret_cast<const unsigned char *>(bin);
98+
UsedImageIndices.push_back(ImgIdx);
99+
}
100+
}
101+
102+
pi_result redefinedPiProgramCreate(pi_context context, const void *il,
103+
size_t length, pi_program *res_program) {
104+
redefinedPiProgramCreateCommon(il);
105+
return PI_SUCCESS;
106+
}
107+
108+
pi_result redefinedPiProgramCreateWithBinary(
109+
pi_context context, pi_uint32 num_devices, const pi_device *device_list,
110+
const size_t *lengths, const unsigned char **binaries,
111+
size_t num_metadata_entries, const pi_device_binary_property *metadata,
112+
pi_int32 *binary_status, pi_program *ret_program) {
113+
redefinedPiProgramCreateCommon(binaries[0]);
114+
return PI_SUCCESS;
115+
}
116+
117+
pi_result redefinedDevicesGet(pi_platform platform, pi_device_type device_type,
118+
pi_uint32 num_entries, pi_device *devices,
119+
pi_uint32 *num_devices) {
120+
if (num_devices)
121+
*num_devices = 2;
122+
123+
if (devices) {
124+
devices[0] = reinterpret_cast<pi_device>(1);
125+
devices[1] = reinterpret_cast<pi_device>(2);
126+
}
127+
128+
return PI_SUCCESS;
129+
}
130+
131+
pi_result redefinedExtDeviceSelectBinary(pi_device device,
132+
pi_device_binary *binaries,
133+
pi_uint32 num_binaries,
134+
pi_uint32 *selected_binary_ind) {
135+
EXPECT_EQ(num_binaries, 1U);
136+
// Treat image 3 as incompatible with one of the devices.
137+
if (TrackedImages.count(binaries[0]->BinaryStart) != 0 &&
138+
*binaries[0]->BinaryStart == 3 &&
139+
device == reinterpret_cast<pi_device>(2)) {
140+
return PI_ERROR_INVALID_BINARY;
141+
}
142+
*selected_binary_ind = 0;
143+
return PI_SUCCESS;
144+
}
145+
146+
void verifyImageUse(const std::vector<unsigned char> &ExpectedImages) {
147+
std::sort(UsedImageIndices.begin(), UsedImageIndices.end());
148+
EXPECT_TRUE(std::is_sorted(ExpectedImages.begin(), ExpectedImages.end()));
149+
EXPECT_EQ(UsedImageIndices, ExpectedImages);
150+
UsedImageIndices.clear();
151+
}
152+
153+
TEST(KernelBundle, DeviceImageStateFiltering) {
154+
sycl::unittest::PiMock Mock;
155+
Mock.redefineAfter<sycl::detail::PiApiKind::piProgramCreate>(
156+
redefinedPiProgramCreate);
157+
Mock.redefineAfter<sycl::detail::PiApiKind::piProgramCreateWithBinary>(
158+
redefinedPiProgramCreateWithBinary);
159+
160+
// No kernel ids specified.
161+
{
162+
const sycl::device Dev = Mock.getPlatform().get_devices()[0];
163+
sycl::context Ctx{Dev};
164+
165+
sycl::kernel_bundle<sycl::bundle_state::executable> KernelBundle =
166+
sycl::get_kernel_bundle<sycl::bundle_state::executable>(Ctx, {Dev});
167+
verifyImageUse({0, 1, 3, 4});
168+
}
169+
170+
sycl::kernel_id KernelAID = sycl::get_kernel_id<KernelA>();
171+
sycl::kernel_id KernelCID = sycl::get_kernel_id<KernelC>();
172+
sycl::kernel_id KernelDID = sycl::get_kernel_id<KernelD>();
173+
174+
// Request specific kernel ids.
175+
{
176+
const sycl::device Dev = Mock.getPlatform().get_devices()[0];
177+
sycl::context Ctx{Dev};
178+
179+
sycl::kernel_bundle<sycl::bundle_state::executable> KernelBundle =
180+
sycl::get_kernel_bundle<sycl::bundle_state::executable>(
181+
Ctx, {Dev}, {KernelAID, KernelCID, KernelDID});
182+
verifyImageUse({1, 3, 4});
183+
}
184+
185+
// Check the case where some executable images are unsupported by one of
186+
// the devices.
187+
{
188+
Mock.redefine<sycl::detail::PiApiKind::piDevicesGet>(redefinedDevicesGet);
189+
Mock.redefine<sycl::detail::PiApiKind::piextDeviceSelectBinary>(
190+
redefinedExtDeviceSelectBinary);
191+
const std::vector<sycl::device> Devs = Mock.getPlatform().get_devices();
192+
sycl::context Ctx{Devs};
193+
194+
sycl::kernel_bundle<sycl::bundle_state::executable> KernelBundle =
195+
sycl::get_kernel_bundle<sycl::bundle_state::executable>(
196+
Ctx, Devs, {KernelAID, KernelCID, KernelDID});
197+
verifyImageUse({1, 2, 3, 4});
198+
}
199+
}
200+
} // namespace

sycl/unittests/helpers/PiImage.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ class PiImage {
262262
MPropertySet.end(),
263263
};
264264
}
265+
const unsigned char *getBinaryPtr() { return &*MBinary.begin(); }
265266

266267
private:
267268
uint16_t MVersion;

sycl/unittests/helpers/TestKernel.hpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,18 @@
88

99
#pragma once
1010

11+
#include "MockKernelInfo.hpp"
1112
#include "PiImage.hpp"
1213

1314
template <size_t KernelSize = 1> class TestKernel;
1415

1516
namespace sycl {
1617
__SYCL_INLINE_VER_NAMESPACE(_V1) {
1718
namespace detail {
18-
template <size_t KernelSize> struct KernelInfo<TestKernel<KernelSize>> {
19-
static constexpr unsigned getNumParams() { return 0; }
20-
static const kernel_param_desc_t &getParamDesc(int) {
21-
static kernel_param_desc_t Dummy;
22-
return Dummy;
23-
}
19+
template <size_t KernelSize>
20+
struct KernelInfo<TestKernel<KernelSize>>
21+
: public unittest::MockKernelInfoBase {
2422
static constexpr const char *getName() { return "TestKernel"; }
25-
static constexpr bool isESIMD() { return false; }
26-
static constexpr bool callsThisItem() { return false; }
27-
static constexpr bool callsAnyThisFreeFunction() { return false; }
2823
static constexpr int64_t getKernelSize() { return KernelSize; }
2924
static constexpr const char *getFileName() { return "TestKernel.hpp"; }
3025
static constexpr const char *getFunctionName() {

0 commit comments

Comments
 (0)