Skip to content

Commit 7a1fa6c

Browse files
[SYCL] Remove extra map lookup for eliminated kernel arguments (#8958)
Retreive kernel argument mask while creating the kernel and bundle it together with the cached PiKernel or in the created sycl::kernel object. This removes an extra map lookup during enqueue of cached kernels.
1 parent 8b0b210 commit 7a1fa6c

15 files changed

+201
-160
lines changed

sycl/source/detail/context_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ std::optional<RT::PiProgram> context_impl::getProgramForDeviceGlobal(
423423
}
424424
if (!BuildRes)
425425
return std::nullopt;
426-
return MKernelProgramCache.waitUntilBuilt<compile_program_error>(BuildRes);
426+
return *MKernelProgramCache.waitUntilBuilt<compile_program_error>(BuildRes);
427427
}
428428

429429
} // namespace detail

sycl/source/detail/jit_compiler.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,7 @@ jit_compiler::fuseKernels(QueueImplPtr Queue,
578578
}
579579
const RTDeviceBinaryImage *DeviceImage = nullptr;
580580
RT::PiProgram Program = nullptr;
581+
const KernelArgMask *EliminatedArgs = nullptr;
581582
if (KernelCG->getKernelBundle() != nullptr) {
582583
// Retrieve the device image from the kernel bundle.
583584
auto KernelBundle = KernelCG->getKernelBundle();
@@ -589,10 +590,12 @@ jit_compiler::fuseKernels(QueueImplPtr Queue,
589590

590591
DeviceImage = SyclKernel->getDeviceImage()->get_bin_image_ref();
591592
Program = SyclKernel->getDeviceImage()->get_program_ref();
593+
EliminatedArgs = SyclKernel->getKernelArgMask();
592594
} else if (KernelCG->MSyclKernel != nullptr) {
593595
DeviceImage =
594596
KernelCG->MSyclKernel->getDeviceImage()->get_bin_image_ref();
595597
Program = KernelCG->MSyclKernel->getDeviceImage()->get_program_ref();
598+
EliminatedArgs = KernelCG->MSyclKernel->getKernelArgMask();
596599
} else {
597600
auto ContextImpl = Queue->getContextImplPtr();
598601
auto Context = detail::createSyclObjFromImpl<context>(ContextImpl);
@@ -602,18 +605,14 @@ jit_compiler::fuseKernels(QueueImplPtr Queue,
602605
KernelCG->MOSModuleHandle, KernelName, Context, Device);
603606
Program = detail::ProgramManager::getInstance().createPIProgram(
604607
*DeviceImage, Context, Device);
608+
EliminatedArgs =
609+
detail::ProgramManager::getInstance().getEliminatedKernelArgMask(
610+
KernelCG->MOSModuleHandle, Program, KernelName);
605611
}
606612
if (!DeviceImage || !Program) {
607613
printPerformanceWarning("No suitable IR available for fusion");
608614
return nullptr;
609615
}
610-
ProgramManager::KernelArgMask EliminatedArgs;
611-
if (Program && (KernelCG->MSyclKernel == nullptr ||
612-
!KernelCG->MSyclKernel->isCreatedFromSource())) {
613-
EliminatedArgs =
614-
detail::ProgramManager::getInstance().getEliminatedKernelArgMask(
615-
KernelCG->MOSModuleHandle, Program, KernelName);
616-
}
617616

618617
// Collect information about the arguments of this kernel.
619618

@@ -634,7 +633,8 @@ jit_compiler::fuseKernels(QueueImplPtr Queue,
634633
// DPC++ internally uses 'true' to indicate that an argument has been
635634
// eliminated, while the JIT compiler uses 'true' to indicate an
636635
// argument is used. Translate this here.
637-
bool Eliminated = !EliminatedArgs.empty() && EliminatedArgs[ArgIndex++];
636+
bool Eliminated = EliminatedArgs && !EliminatedArgs->empty() &&
637+
(*EliminatedArgs)[ArgIndex++];
638638
ArgDescriptor.UsageMask.emplace_back(!Eliminated);
639639

640640
// If the argument has not been eliminated, i.e., is still present on
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//==----------- kernel_arg_mask.hpp - SYCL KernelArgMask -------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
11+
namespace sycl {
12+
__SYCL_INLINE_VER_NAMESPACE(_V1) {
13+
namespace detail {
14+
using KernelArgMask = std::vector<bool>;
15+
inline KernelArgMask createKernelArgMask(const ByteArray &Bytes) {
16+
const int NBytesForSize = 8;
17+
const int NBitsInElement = 8;
18+
std::uint64_t SizeInBits = 0;
19+
20+
KernelArgMask Result;
21+
for (int I = 0; I < NBytesForSize; ++I)
22+
SizeInBits |= static_cast<std::uint64_t>(Bytes[I]) << I * NBitsInElement;
23+
24+
Result.reserve(SizeInBits);
25+
for (std::uint64_t I = 0; I < SizeInBits; ++I) {
26+
std::uint8_t Byte = Bytes[NBytesForSize + (I / NBitsInElement)];
27+
Result.push_back(Byte & (1 << (I % NBitsInElement)));
28+
}
29+
return Result;
30+
}
31+
} // namespace detail
32+
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
33+
} // namespace sycl

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -367,13 +367,15 @@ class kernel_bundle_impl {
367367
detail::getSyclObjImpl(*It);
368368

369369
RT::PiKernel Kernel = nullptr;
370-
std::tie(Kernel, std::ignore) =
370+
const KernelArgMask *ArgMask = nullptr;
371+
std::tie(Kernel, std::ignore, ArgMask) =
371372
detail::ProgramManager::getInstance().getOrCreateKernel(
372373
MContext, KernelID.get_name(), /*PropList=*/{},
373374
DeviceImageImpl->get_program_ref());
374375

375-
std::shared_ptr<kernel_impl> KernelImpl = std::make_shared<kernel_impl>(
376-
Kernel, detail::getSyclObjImpl(MContext), DeviceImageImpl, Self);
376+
std::shared_ptr<kernel_impl> KernelImpl =
377+
std::make_shared<kernel_impl>(Kernel, detail::getSyclObjImpl(MContext),
378+
DeviceImageImpl, Self, ArgMask);
377379

378380
return detail::createSyclObjFromImpl<kernel>(KernelImpl);
379381
}

sycl/source/detail/kernel_impl.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@ __SYCL_INLINE_VER_NAMESPACE(_V1) {
1818
namespace detail {
1919

2020
kernel_impl::kernel_impl(RT::PiKernel Kernel, ContextImplPtr Context,
21-
KernelBundleImplPtr KernelBundleImpl)
21+
KernelBundleImplPtr KernelBundleImpl,
22+
const KernelArgMask *ArgMask)
2223
: kernel_impl(Kernel, Context,
2324
std::make_shared<program_impl>(Context, Kernel),
24-
/*IsCreatedFromSource*/ true, KernelBundleImpl) {
25+
/*IsCreatedFromSource*/ true, KernelBundleImpl, ArgMask) {
2526
// Enable USM indirect access for interoperability kernels.
2627
// Some PI Plugins (like OpenCL) require this call to enable USM
2728
// For others, PI will turn this into a NOP.
@@ -34,11 +35,13 @@ kernel_impl::kernel_impl(RT::PiKernel Kernel, ContextImplPtr Context,
3435

3536
kernel_impl::kernel_impl(RT::PiKernel Kernel, ContextImplPtr ContextImpl,
3637
ProgramImplPtr ProgramImpl, bool IsCreatedFromSource,
37-
KernelBundleImplPtr KernelBundleImpl)
38+
KernelBundleImplPtr KernelBundleImpl,
39+
const KernelArgMask *ArgMask)
3840
: MKernel(Kernel), MContext(ContextImpl),
3941
MProgramImpl(std::move(ProgramImpl)),
4042
MCreatedFromSource(IsCreatedFromSource),
41-
MKernelBundleImpl(std::move(KernelBundleImpl)) {
43+
MKernelBundleImpl(std::move(KernelBundleImpl)),
44+
MKernelArgMaskPtr{ArgMask} {
4245

4346
RT::PiContext Context = nullptr;
4447
// Using the plugin from the passed ContextImpl
@@ -54,10 +57,12 @@ kernel_impl::kernel_impl(RT::PiKernel Kernel, ContextImplPtr ContextImpl,
5457

5558
kernel_impl::kernel_impl(RT::PiKernel Kernel, ContextImplPtr ContextImpl,
5659
DeviceImageImplPtr DeviceImageImpl,
57-
KernelBundleImplPtr KernelBundleImpl)
60+
KernelBundleImplPtr KernelBundleImpl,
61+
const KernelArgMask *ArgMask)
5862
: MKernel(Kernel), MContext(std::move(ContextImpl)), MProgramImpl(nullptr),
5963
MCreatedFromSource(false), MDeviceImageImpl(std::move(DeviceImageImpl)),
60-
MKernelBundleImpl(std::move(KernelBundleImpl)) {
64+
MKernelBundleImpl(std::move(KernelBundleImpl)),
65+
MKernelArgMaskPtr{ArgMask} {
6166

6267
// kernel_impl shared ownership of kernel handle
6368
if (!is_host()) {

sycl/source/detail/kernel_impl.hpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <detail/context_impl.hpp>
1212
#include <detail/device_impl.hpp>
13+
#include <detail/kernel_arg_mask.hpp>
1314
#include <detail/kernel_info.hpp>
1415
#include <sycl/detail/common.hpp>
1516
#include <sycl/detail/pi.h>
@@ -42,7 +43,8 @@ class kernel_impl {
4243
/// \param Context is a valid SYCL context
4344
/// \param KernelBundleImpl is a valid instance of kernel_bundle_impl
4445
kernel_impl(RT::PiKernel Kernel, ContextImplPtr Context,
45-
KernelBundleImplPtr KernelBundleImpl);
46+
KernelBundleImplPtr KernelBundleImpl,
47+
const KernelArgMask *ArgMask = nullptr);
4648

4749
/// Constructs a SYCL kernel instance from a SYCL program and a PiKernel
4850
///
@@ -59,7 +61,8 @@ class kernel_impl {
5961
/// \param KernelBundleImpl is a valid instance of kernel_bundle_impl
6062
kernel_impl(RT::PiKernel Kernel, ContextImplPtr ContextImpl,
6163
ProgramImplPtr ProgramImpl, bool IsCreatedFromSource,
62-
KernelBundleImplPtr KernelBundleImpl);
64+
KernelBundleImplPtr KernelBundleImpl,
65+
const KernelArgMask *ArgMask);
6366

6467
/// Constructs a SYCL kernel_impl instance from a SYCL device_image,
6568
/// kernel_bundle and / PiKernel.
@@ -69,7 +72,8 @@ class kernel_impl {
6972
/// \param KernelBundleImpl is a valid instance of kernel_bundle_impl
7073
kernel_impl(RT::PiKernel Kernel, ContextImplPtr ContextImpl,
7174
DeviceImageImplPtr DeviceImageImpl,
72-
KernelBundleImplPtr KernelBundleImpl);
75+
KernelBundleImplPtr KernelBundleImpl,
76+
const KernelArgMask *ArgMask);
7377

7478
/// Constructs a SYCL kernel for host device
7579
///
@@ -177,6 +181,8 @@ class kernel_impl {
177181
return MNoncacheableEnqueueMutex;
178182
}
179183

184+
const KernelArgMask *getKernelArgMask() const { return MKernelArgMaskPtr; }
185+
180186
private:
181187
RT::PiKernel MKernel;
182188
const ContextImplPtr MContext;
@@ -186,6 +192,7 @@ class kernel_impl {
186192
const KernelBundleImplPtr MKernelBundleImpl;
187193
bool MIsInterop = false;
188194
std::mutex MNoncacheableEnqueueMutex;
195+
const KernelArgMask *MKernelArgMaskPtr;
189196

190197
bool isBuiltInKernel(const device &Device) const;
191198
void checkIfValidForNumArgsInfoQuery() const;

sycl/source/detail/kernel_program_cache.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,28 @@ namespace detail {
1616
KernelProgramCache::~KernelProgramCache() {
1717
for (auto &ProgIt : MCachedPrograms.Cache) {
1818
ProgramWithBuildStateT &ProgWithState = ProgIt.second;
19-
PiProgramT *ToBeDeleted = ProgWithState.Ptr.load();
19+
RT::PiProgram *ToBeDeleted = ProgWithState.Ptr.load();
2020

2121
if (!ToBeDeleted)
2222
continue;
2323

24-
auto KernIt = MKernelsPerProgramCache.find(ToBeDeleted);
24+
auto KernIt = MKernelsPerProgramCache.find(*ToBeDeleted);
2525

2626
if (KernIt != MKernelsPerProgramCache.end()) {
2727
for (auto &p : KernIt->second) {
28-
KernelWithBuildStateT &KernelWithState = p.second;
29-
PiKernelT *Kern = KernelWithState.Ptr.load();
28+
BuildResult<KernelArgMaskPairT> &KernelWithState = p.second;
29+
KernelArgMaskPairT *KernelArgMaskPair = KernelWithState.Ptr.load();
3030

31-
if (Kern) {
31+
if (KernelArgMaskPair) {
3232
const detail::plugin &Plugin = MParentContext->getPlugin();
33-
Plugin.call<PiApiKind::piKernelRelease>(Kern);
33+
Plugin.call<PiApiKind::piKernelRelease>(KernelArgMaskPair->first);
3434
}
3535
}
3636
MKernelsPerProgramCache.erase(KernIt);
3737
}
3838

3939
const detail::plugin &Plugin = MParentContext->getPlugin();
40-
Plugin.call<PiApiKind::piProgramRelease>(ToBeDeleted);
40+
Plugin.call<PiApiKind::piProgramRelease>(*ToBeDeleted);
4141
}
4242
}
4343
} // namespace detail

sycl/source/detail/kernel_program_cache.hpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class KernelProgramCache {
4848
/// Currently there is only a single user - ProgramManager class.
4949
template <typename T> struct BuildResult {
5050
std::atomic<T *> Ptr;
51+
T Val;
5152
std::atomic<BuildState> State;
5253
BuildError Error;
5354

@@ -68,9 +69,7 @@ class KernelProgramCache {
6869
BuildResult(T *P, BuildState S) : Ptr{P}, State{S}, Error{"", 0} {}
6970
};
7071

71-
using PiProgramT = std::remove_pointer<RT::PiProgram>::type;
72-
using PiProgramPtrT = std::atomic<PiProgramT *>;
73-
using ProgramWithBuildStateT = BuildResult<PiProgramT>;
72+
using ProgramWithBuildStateT = BuildResult<RT::PiProgram>;
7473
using ProgramCacheKeyT = std::pair<std::pair<SerializedObj, std::uintptr_t>,
7574
std::pair<RT::PiDevice, std::string>>;
7675
using CommonProgramKeyT = std::pair<std::uintptr_t, RT::PiDevice>;
@@ -84,18 +83,15 @@ class KernelProgramCache {
8483

8584
using ContextPtr = context_impl *;
8685

87-
using PiKernelT = std::remove_pointer<RT::PiKernel>::type;
88-
89-
using PiKernelPtrT = std::atomic<PiKernelT *>;
90-
using KernelWithBuildStateT = BuildResult<PiKernelT>;
91-
using KernelByNameT = std::map<std::string, KernelWithBuildStateT>;
86+
using KernelArgMaskPairT = std::pair<RT::PiKernel, const KernelArgMask *>;
87+
using KernelByNameT = std::map<std::string, BuildResult<KernelArgMaskPairT>>;
9288
using KernelCacheT = std::map<RT::PiProgram, KernelByNameT>;
9389

9490
using KernelFastCacheKeyT =
9591
std::tuple<SerializedObj, OSModuleHandle, RT::PiDevice, std::string,
9692
std::string>;
97-
using KernelFastCacheValT =
98-
std::tuple<RT::PiKernel, std::mutex *, RT::PiProgram>;
93+
using KernelFastCacheValT = std::tuple<RT::PiKernel, std::mutex *,
94+
const KernelArgMask *, RT::PiProgram>;
9995
using KernelFastCacheT = std::map<KernelFastCacheKeyT, KernelFastCacheValT>;
10096

10197
~KernelProgramCache();
@@ -128,7 +124,7 @@ class KernelProgramCache {
128124
return std::make_pair(&Inserted.first->second, Inserted.second);
129125
}
130126

131-
std::pair<KernelWithBuildStateT *, bool>
127+
std::pair<BuildResult<KernelArgMaskPairT> *, bool>
132128
getOrInsertKernel(RT::PiProgram Program, const std::string &KernelName) {
133129
auto LockedCache = acquireKernelsPerProgramCache();
134130
auto &Cache = LockedCache.get()[Program];
@@ -173,7 +169,7 @@ class KernelProgramCache {
173169
if (It != MKernelFastCache.end()) {
174170
return It->second;
175171
}
176-
return std::make_tuple(nullptr, nullptr, nullptr);
172+
return std::make_tuple(nullptr, nullptr, nullptr, nullptr);
177173
}
178174

179175
template <typename KeyT, typename ValT>

sycl/source/detail/program_impl.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,9 @@ kernel program_impl::get_kernel(std::string KernelName,
351351
return createSyclObjFromImpl<kernel>(
352352
std::make_shared<kernel_impl>(MContext, PtrToSelf));
353353
}
354-
return createSyclObjFromImpl<kernel>(
355-
std::make_shared<kernel_impl>(get_pi_kernel(KernelName), MContext,
356-
PtrToSelf, IsCreatedFromSource, nullptr));
354+
auto [Kernel, ArgMask] = get_pi_kernel_arg_mask_pair(KernelName);
355+
return createSyclObjFromImpl<kernel>(std::make_shared<kernel_impl>(
356+
Kernel, MContext, PtrToSelf, IsCreatedFromSource, nullptr, ArgMask));
357357
}
358358

359359
std::vector<std::vector<char>> program_impl::get_binaries() const {
@@ -447,19 +447,20 @@ std::vector<RT::PiDevice> program_impl::get_pi_devices() const {
447447
return PiDevices;
448448
}
449449

450-
RT::PiKernel program_impl::get_pi_kernel(const std::string &KernelName) const {
451-
RT::PiKernel Kernel = nullptr;
450+
std::pair<RT::PiKernel, const KernelArgMask *>
451+
program_impl::get_pi_kernel_arg_mask_pair(const std::string &KernelName) const {
452+
std::pair<RT::PiKernel, const KernelArgMask *> Result;
452453

453454
if (is_cacheable()) {
454-
std::tie(Kernel, std::ignore, std::ignore) =
455+
std::tie(Result.first, std::ignore, Result.second, std::ignore) =
455456
ProgramManager::getInstance().getOrCreateKernel(
456457
MProgramModuleHandle, detail::getSyclObjImpl(get_context()),
457458
detail::getSyclObjImpl(get_devices()[0]), KernelName, this);
458-
getPlugin().call<PiApiKind::piKernelRetain>(Kernel);
459+
getPlugin().call<PiApiKind::piKernelRetain>(Result.first);
459460
} else {
460461
const detail::plugin &Plugin = getPlugin();
461462
RT::PiResult Err = Plugin.call_nocheck<PiApiKind::piKernelCreate>(
462-
MProgram, KernelName.c_str(), &Kernel);
463+
MProgram, KernelName.c_str(), &Result.first);
463464
if (Err == PI_ERROR_INVALID_KERNEL_NAME) {
464465
throw invalid_object_error(
465466
"This instance of program does not contain the kernel requested",
@@ -469,11 +470,11 @@ RT::PiKernel program_impl::get_pi_kernel(const std::string &KernelName) const {
469470

470471
// Some PI Plugins (like OpenCL) require this call to enable USM
471472
// For others, PI will turn this into a NOP.
472-
Plugin.call<PiApiKind::piKernelSetExecInfo>(Kernel, PI_USM_INDIRECT_ACCESS,
473-
sizeof(pi_bool), &PI_TRUE);
473+
Plugin.call<PiApiKind::piKernelSetExecInfo>(
474+
Result.first, PI_USM_INDIRECT_ACCESS, sizeof(pi_bool), &PI_TRUE);
474475
}
475476

476-
return Kernel;
477+
return Result;
477478
}
478479

479480
std::vector<device>

sycl/source/detail/program_impl.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,8 @@ class program_impl {
405405
/// \param KernelName is a string containing PI kernel name.
406406
/// \return an instance of PI kernel with specific name. If kernel is
407407
/// unavailable, an invalid_object_error exception is thrown.
408-
RT::PiKernel get_pi_kernel(const std::string &KernelName) const;
408+
std::pair<RT::PiKernel, const KernelArgMask *>
409+
get_pi_kernel_arg_mask_pair(const std::string &KernelName) const;
409410

410411
/// \return a vector of sorted in ascending order SYCL devices.
411412
std::vector<device> sort_devices_by_cl_device_id(std::vector<device> Devices);

0 commit comments

Comments
 (0)