@@ -39,7 +39,9 @@ class device_image_impl {
39
39
std::vector<kernel_id> KernelIDs, RT::PiProgram Program)
40
40
: MBinImage(BinImage), MContext(std::move(Context)),
41
41
MDevices (std::move(Devices)), MState(State), MProgram(Program),
42
- MKernelIDs(std::move(KernelIDs)) {}
42
+ MKernelIDs(std::move(KernelIDs)) {
43
+ updateSpecConstSymMap ();
44
+ }
43
45
44
46
bool has_kernel (const kernel_id &KernelIDCand) const noexcept {
45
47
return std::binary_search (MKernelIDs.begin (), MKernelIDs.end (),
@@ -60,7 +62,11 @@ class device_image_impl {
60
62
}
61
63
62
64
bool has_specialization_constants () const noexcept {
63
- return !MSpecConstsBlob.empty ();
65
+ // Lock the mutex to prevent when one thread in the middle of writing a
66
+ // new value while another thread is reading the value to pass it to
67
+ // JIT compiler.
68
+ const std::lock_guard<std::mutex> SpecConstLock (MSpecConstAccessMtx);
69
+ return !MSpecConstSymMap.empty ();
64
70
}
65
71
66
72
bool all_specialization_constant_native () const noexcept {
@@ -72,45 +78,69 @@ class device_image_impl {
72
78
// for this spec const should be.
73
79
struct SpecConstDescT {
  // ID of a scalar specialization constant (which might be a member of a
  // composite).
  unsigned ID = 0;
  // Byte offset of this scalar within the composite value; zero for scalar
  // specialization constants.
  unsigned CompositeOffset = 0;
  // Number of bytes copied for this scalar.
  unsigned Size = 0;
  // Byte offset of this scalar's value inside MSpecConstsBlob.
  unsigned BlobOffset = 0;
  // Becomes true once a value has been written for this descriptor.
  bool IsSet = false;
};
78
86
79
- bool has_specialization_constant (unsigned int SpecID) const noexcept {
80
- return std::any_of (MSpecConstDescs.begin (), MSpecConstDescs.end (),
81
- [SpecID](const SpecConstDescT &SpecConstDesc) {
82
- return SpecConstDesc.ID == SpecID;
83
- });
84
- }
85
-
86
- void set_specialization_constant_raw_value (unsigned int SpecID,
87
- const void *Value,
88
- size_t ValueSize) noexcept {
89
- for (const SpecConstDescT &SpecConstDesc : MSpecConstDescs)
90
- if (SpecConstDesc.ID == SpecID) {
91
- // Lock the mutex to prevent when one thread in the middle of writing a
92
- // new value while another thread is reading the value to pass it to
93
- // JIT compiler.
94
- const std::lock_guard<std::mutex> SpecConstLock (MSpecConstAccessMtx);
95
- std::memcpy (MSpecConstsBlob.data () + SpecConstDesc.Offset , Value,
96
- ValueSize);
97
- return ;
98
- }
87
+ bool has_specialization_constant (const char *SpecName) const noexcept {
88
+ // Lock the mutex to prevent when one thread in the middle of writing a
89
+ // new value while another thread is reading the value to pass it to
90
+ // JIT compiler.
91
+ const std::lock_guard<std::mutex> SpecConstLock (MSpecConstAccessMtx);
92
+ return MSpecConstSymMap.count (SpecName) != 0 ;
99
93
}
100
94
101
- void get_specialization_constant_raw_value (unsigned int SpecID,
102
- void *ValueRet,
103
- size_t ValueSize) const noexcept {
104
- for (const SpecConstDescT &SpecConstDesc : MSpecConstDescs)
105
- if (SpecConstDesc.ID == SpecID) {
106
- // Lock the mutex to prevent when one thread in the middle of writing a
107
- // new value while another thread is reading the value to pass it to
108
- // JIT compiler.
109
- const std::lock_guard<std::mutex> SpecConstLock (MSpecConstAccessMtx);
110
- std::memcpy (ValueRet, MSpecConstsBlob.data () + SpecConstDesc.Offset ,
111
- ValueSize);
112
- return ;
113
- }
95
+ void set_specialization_constant_raw_value (const char *SpecName,
96
+ const void *Value) noexcept {
97
+ // Lock the mutex to prevent when one thread in the middle of writing a
98
+ // new value while another thread is reading the value to pass it to
99
+ // JIT compiler.
100
+ const std::lock_guard<std::mutex> SpecConstLock (MSpecConstAccessMtx);
101
+
102
+ if (MSpecConstSymMap.count (std::string{SpecName}) == 0 )
103
+ return ;
104
+
105
+ std::vector<SpecConstDescT> &Descs =
106
+ MSpecConstSymMap[std::string{SpecName}];
107
+ for (SpecConstDescT &Desc : Descs) {
108
+ Desc.IsSet = true ;
109
+ std::memcpy (MSpecConstsBlob.data () + Desc.BlobOffset ,
110
+ static_cast <const char *>(Value) + Desc.CompositeOffset ,
111
+ Desc.Size );
112
+ }
113
+ }
114
+
115
+ void get_specialization_constant_raw_value (const char *SpecName,
116
+ void *ValueRet) const noexcept {
117
+ assert (is_specialization_constant_set (SpecName));
118
+ // Lock the mutex to prevent when one thread in the middle of writing a
119
+ // new value while another thread is reading the value to pass it to
120
+ // JIT compiler.
121
+ const std::lock_guard<std::mutex> SpecConstLock (MSpecConstAccessMtx);
122
+
123
+ // operator[] can't be used here, since it's not marked as const
124
+ const std::vector<SpecConstDescT> &Descs =
125
+ MSpecConstSymMap.at (std::string{SpecName});
126
+ for (const SpecConstDescT &Desc : Descs) {
127
+
128
+ std::memcpy (static_cast <char *>(ValueRet) + Desc.CompositeOffset ,
129
+ MSpecConstsBlob.data () + Desc.BlobOffset , Desc.Size );
130
+ }
131
+ }
132
+
133
+ bool is_specialization_constant_set (const char *SpecName) const noexcept {
134
+ // Lock the mutex to prevent when one thread in the middle of writing a
135
+ // new value while another thread is reading the value to pass it to
136
+ // JIT compiler.
137
+ const std::lock_guard<std::mutex> SpecConstLock (MSpecConstAccessMtx);
138
+ if (MSpecConstSymMap.count (std::string{SpecName}) == 0 )
139
+ return false ;
140
+
141
+ const std::vector<SpecConstDescT> &Descs =
142
+ MSpecConstSymMap.at (std::string{SpecName});
143
+ return Descs.front ().IsSet ;
114
144
}
115
145
116
146
bundle_state get_state () const noexcept { return MState; }
@@ -137,8 +167,13 @@ class device_image_impl {
137
167
return MSpecConstsBlob;
138
168
}
139
169
140
- std::vector<SpecConstDescT> &get_spec_const_offsets_ref () noexcept {
141
- return MSpecConstDescs;
170
+ const std::map<std::string, std::vector<SpecConstDescT>> &
171
+ get_spec_const_data_ref () const noexcept {
172
+ return MSpecConstSymMap;
173
+ }
174
+
175
+ std::mutex &get_spec_const_data_lock () noexcept {
176
+ return MSpecConstAccessMtx;
142
177
}
143
178
144
179
~device_image_impl () {
@@ -150,6 +185,49 @@ class device_image_impl {
150
185
}
151
186
152
187
private:
188
+ void updateSpecConstSymMap () {
189
+ if (MBinImage) {
190
+ const pi::DeviceBinaryImage::PropertyRange &SCRange =
191
+ MBinImage->getSpecConstants ();
192
+ using SCItTy = pi::DeviceBinaryImage::PropertyRange::ConstIterator;
193
+
194
+ // This variable is used to calculate spec constant value offset in a
195
+ // flat byte array.
196
+ unsigned BlobOffset = 0 ;
197
+ for (SCItTy SCIt : SCRange) {
198
+ const char *SCName = (*SCIt)->Name ;
199
+
200
+ pi::ByteArray Descriptors =
201
+ pi::DeviceBinaryProperty (*SCIt).asByteArray ();
202
+ assert (Descriptors.size () > 8 && " Unexpected property size" );
203
+
204
+ // Expected layout is vector of 3-component tuples (flattened into a
205
+ // vector of scalars), where each tuple consists of: ID of a scalar spec
206
+ // constant, (which might be a member of the composite); offset, which
207
+ // is used to calculate location of scalar member within the composite
208
+ // or zero for scalar spec constants; size of a spec constant
209
+ constexpr size_t NumElements = 3 ;
210
+ assert (((Descriptors.size () - 8 ) / sizeof (std::uint32_t )) %
211
+ NumElements ==
212
+ 0 &&
213
+ " unexpected layout of composite spec const descriptors" );
214
+ auto *It = reinterpret_cast <const std::uint32_t *>(&Descriptors[8 ]);
215
+ auto *End = reinterpret_cast <const std::uint32_t *>(&Descriptors[0 ] +
216
+ Descriptors.size ());
217
+ while (It != End) {
218
+ // The map is not locked here because updateSpecConstSymMap() is only
219
+ // supposed to be called from c'tor.
220
+ MSpecConstSymMap[std::string{SCName}].push_back (
221
+ SpecConstDescT{/* ID*/ It[0 ], /* CompositeOffset*/ It[1 ],
222
+ /* Size*/ It[2 ], BlobOffset});
223
+ BlobOffset += /* Size*/ It[2 ];
224
+ It += NumElements;
225
+ }
226
+ }
227
+ MSpecConstsBlob.resize (BlobOffset);
228
+ }
229
+ }
230
+
153
231
const RTDeviceBinaryImage *MBinImage = nullptr ;
154
232
context MContext;
155
233
std::vector<device> MDevices;
@@ -166,8 +244,9 @@ class device_image_impl {
166
244
// Binary blob which can have values of all specialization constants in the
167
245
// image
168
246
std::vector<unsigned char > MSpecConstsBlob;
169
- // Contains list of spec ID + their offsets in the MSpecConstsBlob
170
- std::vector<SpecConstDescT> MSpecConstDescs;
247
+ // Contains map of spec const names to their descriptions + offsets in
248
+ // the MSpecConstsBlob
249
+ std::map<std::string, std::vector<SpecConstDescT>> MSpecConstSymMap;
171
250
};
172
251
173
252
} // namespace detail
0 commit comments