Skip to content

Commit ba3d657

Browse files
author
Alexander Batashev
authored
[SYCL] Add support for set(get)_specialization_constant (#3501)
This patch introduces Spec Constant Name to Spec ID mapping for `device_image`s and implements `kernel_bundle::set(get)_specialization_constant()` and `handler::set(get)_specialization_constant` member functions. See [4.9.5.2. Setting and getting the value of a specialization constant](https://www.khronos.org/registry/SYCL/specs/sycl-2020/html/sycl-2020.html#_setting_and_getting_the_value_of_a_specialization_constant).
1 parent cec6469 commit ba3d657

File tree

12 files changed

+322
-99
lines changed

12 files changed

+322
-99
lines changed

sycl/include/CL/sycl/detail/kernel_desc.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ __SYCL_INLINE_NAMESPACE(cl) {
1717
namespace sycl {
1818
namespace detail {
1919

20+
// This guard is needed because the libsycl.so can be compiled with C++ <=14
21+
// while the code requires C++17. This code is not supposed to be used by the
22+
// libsycl.so so it should not be a problem.
23+
#if __cplusplus > 201402L
24+
template <auto &S> struct specialization_id_name_generator {};
25+
#endif
26+
2027
#ifndef __SYCL_DEVICE_ONLY__
2128
#define _Bool bool
2229
#endif
@@ -49,6 +56,14 @@ template <class Name> struct SpecConstantInfo {
4956
static constexpr const char *getName() { return ""; }
5057
};
5158

59+
#if __cplusplus >= 201703L
60+
// Translates SYCL 2020 specialization constant type to its name.
61+
template <auto &SpecName> const char *get_spec_constant_symbolic_ID() {
62+
return __builtin_unique_stable_name(
63+
specialization_id_name_generator<SpecName>);
64+
}
65+
#endif
66+
5267
#ifndef __SYCL_UNNAMED_LAMBDA__
5368
template <class KernelNameType> struct KernelInfo {
5469
static constexpr unsigned getNumParams() { return 0; }

sycl/include/CL/sycl/detail/sycl_fe_intrins.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,12 @@ SYCL_EXTERNAL T __sycl_getCompositeSpecConstantValue(const char *ID);
3737
// are not available.
3838
template <typename T>
3939
SYCL_EXTERNAL T __sycl_getScalar2020SpecConstantValue(const char *SymbolicID,
40-
void *DefaultValue,
41-
void *RTBuffer);
40+
const void *DefaultValue,
41+
const void *RTBuffer);
4242

4343
template <typename T>
44-
SYCL_EXTERNAL T __sycl_getComposite2020SpecConstantValue(const char *SymbolicID,
45-
void *DefaultValue,
46-
void *RTBuffer);
44+
SYCL_EXTERNAL T __sycl_getComposite2020SpecConstantValue(
45+
const char *SymbolicID, const void *DefaultValue, const void *RTBuffer);
4746

4847
// Request a fixed-size allocation in local address space at kernel scope.
4948
extern "C" SYCL_EXTERNAL __attribute__((opencl_local)) std::uint8_t *

sycl/include/CL/sycl/handler.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1048,7 +1048,7 @@ class __SYCL_EXPORT handler {
10481048
}
10491049

10501050
template <auto &SpecName>
1051-
typename std::remove_reference_t<decltype(SpecName)>::type
1051+
typename std::remove_reference_t<decltype(SpecName)>::value_type
10521052
get_specialization_constant() const {
10531053

10541054
std::shared_ptr<detail::kernel_bundle_impl> KernelBundleImplPtr =

sycl/include/CL/sycl/kernel_bundle.hpp

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,16 @@ class __SYCL_EXPORT kernel_bundle_plain {
168168
// \returns an iterator to the last device image kernel_bundle contains
169169
const device_image_plain *end() const;
170170

171+
bool has_specialization_constant_impl(const char *SpecName) const noexcept;
172+
173+
void set_specialization_constant_impl(const char *SpecName,
174+
void *Value) noexcept;
175+
176+
void get_specialization_constant_impl(const char *SpecName, void *Value) const
177+
noexcept;
178+
179+
bool is_specialization_constant_set(const char *SpecName) const noexcept;
180+
171181
detail::KernelBundleImplPtr impl;
172182
};
173183

@@ -247,9 +257,8 @@ class kernel_bundle : public detail::kernel_bundle_plain {
247257
/// \returns true if any device image in the kernel_bundle uses specialization
248258
/// constant whose address is SpecName
249259
template <auto &SpecName> bool has_specialization_constant() const noexcept {
250-
throw sycl::runtime_error(
251-
"kernel_bundle::has_specialization_constant is not implemented yet",
252-
PI_INVALID_OPERATION);
260+
const char *SpecSymName = detail::get_spec_constant_symbolic_ID<SpecName>();
261+
return has_specialization_constant_impl(SpecSymName);
253262
}
254263

255264
/// Sets the value of the specialization constant whose address is SpecName
@@ -259,20 +268,27 @@ class kernel_bundle : public detail::kernel_bundle_plain {
259268
typename = detail::enable_if_t<_State == bundle_state::input>>
260269
void set_specialization_constant(
261270
typename std::remove_reference_t<decltype(SpecName)>::value_type Value) {
262-
(void)Value;
263-
throw sycl::runtime_error(
264-
"kernel_bundle::set_specialization_constant is not implemented yet",
265-
PI_INVALID_OPERATION);
271+
const char *SpecSymName = detail::get_spec_constant_symbolic_ID<SpecName>();
272+
set_specialization_constant_impl(SpecSymName, &Value);
266273
}
267274

268-
/// The value of the specialization constant whose address is SpecName for
269-
/// this kernel bundle.
275+
/// \returns the value of the specialization constant whose address is
276+
/// SpecName for this kernel bundle.
270277
template <auto &SpecName>
271278
typename std::remove_reference_t<decltype(SpecName)>::value_type
272279
get_specialization_constant() const {
273-
throw sycl::runtime_error(
274-
"kernel_bundle::get_specialization_constant is not implemented yet",
275-
PI_INVALID_OPERATION);
280+
const char *SpecSymName = detail::get_spec_constant_symbolic_ID<SpecName>();
281+
if (!is_specialization_constant_set(SpecSymName))
282+
return SpecName.getDefaultValue();
283+
284+
using SCType =
285+
typename std::remove_reference_t<decltype(SpecName)>::value_type;
286+
287+
std::array<char *, sizeof(SCType)> RetValue;
288+
289+
get_specialization_constant_impl(SpecSymName, RetValue.data());
290+
291+
return *reinterpret_cast<SCType *>(RetValue.data());
276292
}
277293
#endif
278294

sycl/include/CL/sycl/kernel_handler.hpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,22 @@
88

99
#pragma once
1010

11-
__SYCL_INLINE_NAMESPACE(cl) {
12-
namespace sycl {
13-
namespace detail {
14-
15-
// This guard is needed because the libsycl.so can compiled with C++ <=14
16-
// while the code requires C++17. This code is not supposed to be used by the
17-
// libsycl.so so it should not be a problem.
18-
#if __cplusplus > 201402L
19-
template <auto &S> struct specialization_id_name_generator {};
20-
#endif
11+
#include <CL/sycl/detail/kernel_desc.hpp>
12+
#include <CL/sycl/exception.hpp>
2113

22-
} // namespace detail
14+
#include <type_traits>
2315

16+
__SYCL_INLINE_NAMESPACE(cl) {
17+
namespace sycl {
2418
/// Reading the value of a specialization constant
2519
///
2620
/// \ingroup sycl_api
2721
class kernel_handler {
2822
public:
2923
#if __cplusplus > 201402L
3024
template <auto &S>
31-
typename std::remove_reference_t<decltype(S)> get_specialization_constant() {
25+
typename std::remove_reference_t<decltype(S)>::value_type
26+
get_specialization_constant() {
3227
#ifdef __SYCL_DEVICE_ONLY__
3328
return getSpecializationConstantOnDevice<S>();
3429
#else
@@ -48,16 +43,20 @@ class kernel_handler {
4843
}
4944

5045
#ifdef __SYCL_DEVICE_ONLY__
51-
template <auto &S, typename T = std::remove_reference_t<decltype(S)>,
52-
std::enable_if_t<std::is_fundamental_v<T>> * = nullptr>
46+
template <
47+
auto &S,
48+
typename T = typename std::remove_reference_t<decltype(S)>::value_type,
49+
std::enable_if_t<std::is_fundamental_v<T>> * = nullptr>
5350
T getSpecializationConstantOnDevice() {
5451
const char *SymbolicID = __builtin_unique_stable_name(
5552
detail::specialization_id_name_generator<S>);
5653
return __sycl_getScalar2020SpecConstantValue<T>(
5754
SymbolicID, &S, MSpecializationConstantsBuffer);
5855
}
59-
template <auto &S, typename T = std::remove_reference_t<decltype(S)>,
60-
std::enable_if_t<std::is_compound_v<T>> * = nullptr>
56+
template <
57+
auto &S,
58+
typename T = typename std::remove_reference_t<decltype(S)>::value_type,
59+
std::enable_if_t<std::is_compound_v<T>> * = nullptr>
6160
T getSpecializationConstantOnDevice() {
6261
const char *SymbolicID = __builtin_unique_stable_name(
6362
detail::specialization_id_name_generator<S>);

sycl/include/CL/sycl/specialization_id.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ template <typename T> class specialization_id {
2828
specialization_id &operator=(specialization_id &&rhs) = delete;
2929

3030
private:
31+
template <bundle_state State> friend class kernel_bundle;
32+
T getDefaultValue() const noexcept { return MDefaultValue; }
33+
3134
T MDefaultValue;
3235
};
3336

sycl/source/detail/device_image_impl.hpp

Lines changed: 119 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ class device_image_impl {
3939
std::vector<kernel_id> KernelIDs, RT::PiProgram Program)
4040
: MBinImage(BinImage), MContext(std::move(Context)),
4141
MDevices(std::move(Devices)), MState(State), MProgram(Program),
42-
MKernelIDs(std::move(KernelIDs)) {}
42+
MKernelIDs(std::move(KernelIDs)) {
43+
updateSpecConstSymMap();
44+
}
4345

4446
bool has_kernel(const kernel_id &KernelIDCand) const noexcept {
4547
return std::binary_search(MKernelIDs.begin(), MKernelIDs.end(),
@@ -60,7 +62,11 @@ class device_image_impl {
6062
}
6163

6264
bool has_specialization_constants() const noexcept {
63-
return !MSpecConstsBlob.empty();
65+
// Lock the mutex to prevent when one thread in the middle of writing a
66+
// new value while another thread is reading the value to pass it to
67+
// JIT compiler.
68+
const std::lock_guard<std::mutex> SpecConstLock(MSpecConstAccessMtx);
69+
return !MSpecConstSymMap.empty();
6470
}
6571

6672
bool all_specialization_constant_native() const noexcept {
@@ -72,45 +78,69 @@ class device_image_impl {
7278
// for this spec const should be.
7379
struct SpecConstDescT {
7480
unsigned int ID = 0;
75-
unsigned int Offset = 0;
81+
unsigned int CompositeOffset = 0;
82+
unsigned int Size = 0;
83+
unsigned int BlobOffset = 0;
7684
bool IsSet = false;
7785
};
7886

79-
bool has_specialization_constant(unsigned int SpecID) const noexcept {
80-
return std::any_of(MSpecConstDescs.begin(), MSpecConstDescs.end(),
81-
[SpecID](const SpecConstDescT &SpecConstDesc) {
82-
return SpecConstDesc.ID == SpecID;
83-
});
84-
}
85-
86-
void set_specialization_constant_raw_value(unsigned int SpecID,
87-
const void *Value,
88-
size_t ValueSize) noexcept {
89-
for (const SpecConstDescT &SpecConstDesc : MSpecConstDescs)
90-
if (SpecConstDesc.ID == SpecID) {
91-
// Lock the mutex to prevent when one thread in the middle of writing a
92-
// new value while another thread is reading the value to pass it to
93-
// JIT compiler.
94-
const std::lock_guard<std::mutex> SpecConstLock(MSpecConstAccessMtx);
95-
std::memcpy(MSpecConstsBlob.data() + SpecConstDesc.Offset, Value,
96-
ValueSize);
97-
return;
98-
}
87+
bool has_specialization_constant(const char *SpecName) const noexcept {
88+
// Lock the mutex to prevent when one thread in the middle of writing a
89+
// new value while another thread is reading the value to pass it to
90+
// JIT compiler.
91+
const std::lock_guard<std::mutex> SpecConstLock(MSpecConstAccessMtx);
92+
return MSpecConstSymMap.count(SpecName) != 0;
9993
}
10094

101-
void get_specialization_constant_raw_value(unsigned int SpecID,
102-
void *ValueRet,
103-
size_t ValueSize) const noexcept {
104-
for (const SpecConstDescT &SpecConstDesc : MSpecConstDescs)
105-
if (SpecConstDesc.ID == SpecID) {
106-
// Lock the mutex to prevent when one thread in the middle of writing a
107-
// new value while another thread is reading the value to pass it to
108-
// JIT compiler.
109-
const std::lock_guard<std::mutex> SpecConstLock(MSpecConstAccessMtx);
110-
std::memcpy(ValueRet, MSpecConstsBlob.data() + SpecConstDesc.Offset,
111-
ValueSize);
112-
return;
113-
}
95+
void set_specialization_constant_raw_value(const char *SpecName,
96+
const void *Value) noexcept {
97+
// Lock the mutex to prevent when one thread in the middle of writing a
98+
// new value while another thread is reading the value to pass it to
99+
// JIT compiler.
100+
const std::lock_guard<std::mutex> SpecConstLock(MSpecConstAccessMtx);
101+
102+
if (MSpecConstSymMap.count(std::string{SpecName}) == 0)
103+
return;
104+
105+
std::vector<SpecConstDescT> &Descs =
106+
MSpecConstSymMap[std::string{SpecName}];
107+
for (SpecConstDescT &Desc : Descs) {
108+
Desc.IsSet = true;
109+
std::memcpy(MSpecConstsBlob.data() + Desc.BlobOffset,
110+
static_cast<const char *>(Value) + Desc.CompositeOffset,
111+
Desc.Size);
112+
}
113+
}
114+
115+
void get_specialization_constant_raw_value(const char *SpecName,
116+
void *ValueRet) const noexcept {
117+
assert(is_specialization_constant_set(SpecName));
118+
// Lock the mutex to prevent when one thread in the middle of writing a
119+
// new value while another thread is reading the value to pass it to
120+
// JIT compiler.
121+
const std::lock_guard<std::mutex> SpecConstLock(MSpecConstAccessMtx);
122+
123+
// operator[] can't be used here, since it's not marked as const
124+
const std::vector<SpecConstDescT> &Descs =
125+
MSpecConstSymMap.at(std::string{SpecName});
126+
for (const SpecConstDescT &Desc : Descs) {
127+
128+
std::memcpy(static_cast<char *>(ValueRet) + Desc.CompositeOffset,
129+
MSpecConstsBlob.data() + Desc.BlobOffset, Desc.Size);
130+
}
131+
}
132+
133+
bool is_specialization_constant_set(const char *SpecName) const noexcept {
134+
// Lock the mutex to prevent when one thread in the middle of writing a
135+
// new value while another thread is reading the value to pass it to
136+
// JIT compiler.
137+
const std::lock_guard<std::mutex> SpecConstLock(MSpecConstAccessMtx);
138+
if (MSpecConstSymMap.count(std::string{SpecName}) == 0)
139+
return false;
140+
141+
const std::vector<SpecConstDescT> &Descs =
142+
MSpecConstSymMap.at(std::string{SpecName});
143+
return Descs.front().IsSet;
114144
}
115145

116146
bundle_state get_state() const noexcept { return MState; }
@@ -137,8 +167,13 @@ class device_image_impl {
137167
return MSpecConstsBlob;
138168
}
139169

140-
std::vector<SpecConstDescT> &get_spec_const_offsets_ref() noexcept {
141-
return MSpecConstDescs;
170+
const std::map<std::string, std::vector<SpecConstDescT>> &
171+
get_spec_const_data_ref() const noexcept {
172+
return MSpecConstSymMap;
173+
}
174+
175+
std::mutex &get_spec_const_data_lock() noexcept {
176+
return MSpecConstAccessMtx;
142177
}
143178

144179
~device_image_impl() {
@@ -150,6 +185,49 @@ class device_image_impl {
150185
}
151186

152187
private:
188+
void updateSpecConstSymMap() {
189+
if (MBinImage) {
190+
const pi::DeviceBinaryImage::PropertyRange &SCRange =
191+
MBinImage->getSpecConstants();
192+
using SCItTy = pi::DeviceBinaryImage::PropertyRange::ConstIterator;
193+
194+
// This variable is used to calculate spec constant value offset in a
195+
// flat byte array.
196+
unsigned BlobOffset = 0;
197+
for (SCItTy SCIt : SCRange) {
198+
const char *SCName = (*SCIt)->Name;
199+
200+
pi::ByteArray Descriptors =
201+
pi::DeviceBinaryProperty(*SCIt).asByteArray();
202+
assert(Descriptors.size() > 8 && "Unexpected property size");
203+
204+
// Expected layout is vector of 3-component tuples (flattened into a
205+
// vector of scalars), where each tuple consists of: ID of a scalar spec
206+
// constant, (which might be a member of the composite); offset, which
207+
// is used to calculate location of scalar member within the composite
208+
// or zero for scalar spec constants; size of a spec constant
209+
constexpr size_t NumElements = 3;
210+
assert(((Descriptors.size() - 8) / sizeof(std::uint32_t)) %
211+
NumElements ==
212+
0 &&
213+
"unexpected layout of composite spec const descriptors");
214+
auto *It = reinterpret_cast<const std::uint32_t *>(&Descriptors[8]);
215+
auto *End = reinterpret_cast<const std::uint32_t *>(&Descriptors[0] +
216+
Descriptors.size());
217+
while (It != End) {
218+
// The map is not locked here because updateSpecConstSymMap() is only
219+
// supposed to be called from c'tor.
220+
MSpecConstSymMap[std::string{SCName}].push_back(
221+
SpecConstDescT{/*ID*/ It[0], /*CompositeOffset*/ It[1],
222+
/*Size*/ It[2], BlobOffset});
223+
BlobOffset += /*Size*/ It[2];
224+
It += NumElements;
225+
}
226+
}
227+
MSpecConstsBlob.resize(BlobOffset);
228+
}
229+
}
230+
153231
const RTDeviceBinaryImage *MBinImage = nullptr;
154232
context MContext;
155233
std::vector<device> MDevices;
@@ -166,8 +244,9 @@ class device_image_impl {
166244
// Binary blob which can have values of all specialization constants in the
167245
// image
168246
std::vector<unsigned char> MSpecConstsBlob;
169-
// Contains list of spec ID + their offsets in the MSpecConstsBlob
170-
std::vector<SpecConstDescT> MSpecConstDescs;
247+
// Contains map of spec const names to their descriptions + offsets in
248+
// the MSpecConstsBlob
249+
std::map<std::string, std::vector<SpecConstDescT>> MSpecConstSymMap;
171250
};
172251

173252
} // namespace detail

0 commit comments

Comments
 (0)