Skip to content

Improve get_kernel_bundle performance #5496

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sycl/source/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ make_kernel_bundle(pi_native_handle NativeHandle, const context &TargetContext,
// this by pre-building the device image and extracting kernel info. We can't
// do the same to user images, since they may contain references to undefined
// symbols (e.g. when kernel_bundle is supposed to be joined with another).
std::vector<kernel_id> KernelIDs{};
auto KernelIDs = std::make_shared<std::vector<kernel_id>>();
auto DevImgImpl = std::make_shared<device_image_impl>(
nullptr, TargetContext, Devices, State, KernelIDs, PiProgram);
device_image_plain DevImg{DevImgImpl};
Expand Down
25 changes: 17 additions & 8 deletions sycl/source/detail/device_image_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ __SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace detail {

template <class T> struct LessByHash {
bool operator()(const T &LHS, const T &RHS) const {
return getSyclObjImpl(LHS) < getSyclObjImpl(RHS);
}
};

// The class is impl counterpart for sycl::device_image
// It can represent a program in different states, kernel_id's it has and state
// of specialization constants for it
Expand All @@ -51,7 +57,8 @@ class device_image_impl {

device_image_impl(const RTDeviceBinaryImage *BinImage, context Context,
std::vector<device> Devices, bundle_state State,
std::vector<kernel_id> KernelIDs, RT::PiProgram Program)
std::shared_ptr<std::vector<kernel_id>> KernelIDs,
RT::PiProgram Program)
: MBinImage(BinImage), MContext(std::move(Context)),
MDevices(std::move(Devices)), MState(State), MProgram(Program),
MKernelIDs(std::move(KernelIDs)) {
Expand All @@ -60,17 +67,17 @@ class device_image_impl {

device_image_impl(const RTDeviceBinaryImage *BinImage, context Context,
std::vector<device> Devices, bundle_state State,
std::vector<kernel_id> KernelIDs, RT::PiProgram Program,
const SpecConstMapT &SpecConstMap,
std::shared_ptr<std::vector<kernel_id>> KernelIDs,
RT::PiProgram Program, const SpecConstMapT &SpecConstMap,
const std::vector<unsigned char> &SpecConstsBlob)
: MBinImage(BinImage), MContext(std::move(Context)),
MDevices(std::move(Devices)), MState(State), MProgram(Program),
MKernelIDs(std::move(KernelIDs)), MSpecConstsBlob(SpecConstsBlob),
MSpecConstSymMap(SpecConstMap) {}

bool has_kernel(const kernel_id &KernelIDCand) const noexcept {
return std::binary_search(MKernelIDs.begin(), MKernelIDs.end(),
KernelIDCand, LessByNameComp{});
return std::binary_search(MKernelIDs->begin(), MKernelIDs->end(),
KernelIDCand, LessByHash<kernel_id>{});
}

bool has_kernel(const kernel_id &KernelIDCand,
Expand All @@ -83,7 +90,7 @@ class device_image_impl {
}

const std::vector<kernel_id> &get_kernel_ids() const noexcept {
return MKernelIDs;
return *MKernelIDs;
}

bool has_specialization_constants() const noexcept {
Expand Down Expand Up @@ -176,7 +183,9 @@ class device_image_impl {

const context &get_context() const noexcept { return MContext; }

std::vector<kernel_id> &get_kernel_ids_ref() noexcept { return MKernelIDs; }
std::shared_ptr<std::vector<kernel_id>> &get_kernel_ids_ptr() noexcept {
return MKernelIDs;
}

std::vector<unsigned char> &get_spec_const_blob_ref() noexcept {
return MSpecConstsBlob;
Expand Down Expand Up @@ -312,7 +321,7 @@ class device_image_impl {
RT::PiProgram MProgram = nullptr;
// List of kernel ids available in this image, elements should be sorted
// according to LessByNameComp
std::vector<kernel_id> MKernelIDs;
std::shared_ptr<std::vector<kernel_id>> MKernelIDs;

// A mutex for sycnhronizing access to spec constants blob. Mutable because
// needs to be locked in the const method for getting spec constant value.
Expand Down
6 changes: 0 additions & 6 deletions sycl/source/detail/kernel_bundle_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@ __SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
namespace detail {

template <class T> struct LessByHash {
bool operator()(const T &LHS, const T &RHS) const {
return getSyclObjImpl(LHS) < getSyclObjImpl(RHS);
}
};

static bool checkAllDevicesAreInContext(const std::vector<device> &Devices,
const context &Context) {
const std::vector<device> &ContextDevices = Context.get_devices();
Expand Down
Loading