@@ -39,23 +39,60 @@ namespace pybind11
39
39
namespace detail
40
40
{
41
41
42
+ #define DPCTL_TYPE_CASTER (type , py_name ) \
43
+ protected: \
44
+ std::unique_ptr<type> value; \
45
+ \
46
+ public: \
47
+ static constexpr auto name = py_name; \
48
+ template < \
49
+ typename T_, \
50
+ ::pybind11::detail::enable_if_t< \
51
+ std::is_same<type, ::pybind11::detail::remove_cv_t<T_>>::value, \
52
+ int> = 0> \
53
+ static ::pybind11::handle cast(T_ *src, \
54
+ ::pybind11::return_value_policy policy, \
55
+ ::pybind11::handle parent) \
56
+ { \
57
+ if (!src) \
58
+ return ::pybind11::none().release(); \
59
+ if (policy == ::pybind11::return_value_policy::take_ownership) { \
60
+ auto h = cast(std::move(*src), policy, parent); \
61
+ delete src; \
62
+ return h; \
63
+ } \
64
+ return cast(*src, policy, parent); \
65
+ } \
66
+ operator type *() \
67
+ { \
68
+ return value.get(); \
69
+ } /* NOLINT(bugprone-macro-parentheses) */ \
70
+ operator type & () \
71
+ { \
72
+ return * value ; \
73
+ } /* NOLINT(bugprone-macro-parentheses) */ \
74
+ operator type && ( ) && \
75
+ { \
76
+ return std ::move (* value ); \
77
+ } /* NOLINT(bugprone-macro-parentheses) */ \
78
+ template < typename T_ > \
79
+ using cast_op_type = ::pybind11 ::detail ::movable_cast_op_type < T_ >
80
+
42
81
/* This type caster associates ``sycl::queue`` C++ class with
43
82
* :class:`dpctl.SyclQueue` for the purposes of generation of
44
83
* Python bindings by pybind11.
45
84
*/
46
85
template < > struct type_caster < sycl ::queue >
47
86
{
48
87
public :
49
- PYBIND11_TYPE_CASTER (sycl ::queue , _ ("dpctl.SyclQueue" ));
50
-
51
88
bool load (handle src , bool )
52
89
{
53
90
PyObject * source = src .ptr ();
54
91
if (PyObject_TypeCheck (source , & PySyclQueueType )) {
55
92
DPCTLSyclQueueRef QRef = SyclQueue_GetQueueRef (
56
93
reinterpret_cast < PySyclQueueObject * > (source ));
57
- sycl :: queue * q = reinterpret_cast < sycl ::queue * > ( QRef );
58
- value = * q ;
94
+ value = std :: make_unique < sycl ::queue > (
95
+ * ( reinterpret_cast < sycl :: queue * > ( QRef ))) ;
59
96
return true;
60
97
}
61
98
else {
@@ -69,6 +106,8 @@ template <> struct type_caster<sycl::queue>
69
106
auto tmp = SyclQueue_Make (reinterpret_cast < DPCTLSyclQueueRef > (& src ));
70
107
return handle (reinterpret_cast < PyObject * > (tmp ));
71
108
}
109
+
110
+ DPCTL_TYPE_CASTER (sycl ::queue , _ ("dpctl.SyclQueue" ));
72
111
};
73
112
74
113
/* This type caster associates ``sycl::device`` C++ class with
@@ -78,20 +117,14 @@ template <> struct type_caster<sycl::queue>
78
117
template < > struct type_caster < sycl ::device >
79
118
{
80
119
public :
81
- PYBIND11_TYPE_CASTER (sycl ::device , _ ("dpctl.SyclDevice" ));
82
-
83
120
bool load (handle src , bool )
84
121
{
85
122
PyObject * source = src .ptr ();
86
123
if (PyObject_TypeCheck (source , & PySyclDeviceType )) {
87
124
DPCTLSyclDeviceRef DRef = SyclDevice_GetDeviceRef (
88
125
reinterpret_cast < PySyclDeviceObject * > (source ));
89
- sycl ::device * d = reinterpret_cast < sycl ::device * > (DRef );
90
- value = * d ;
91
- return true;
92
- }
93
- else if (source == Py_None ) {
94
- value = sycl ::device {};
126
+ value = std ::make_unique < sycl ::device > (
127
+ * (reinterpret_cast < sycl ::device * > (DRef )));
95
128
return true;
96
129
}
97
130
else {
@@ -105,6 +138,8 @@ template <> struct type_caster<sycl::device>
105
138
auto tmp = SyclDevice_Make (reinterpret_cast < DPCTLSyclDeviceRef > (& src ));
106
139
return handle (reinterpret_cast < PyObject * > (tmp ));
107
140
}
141
+
142
+ DPCTL_TYPE_CASTER (sycl ::device , _ ("dpctl.SyclDevice" ));
108
143
};
109
144
110
145
/* This type caster associates ``sycl::context`` C++ class with
@@ -114,16 +149,14 @@ template <> struct type_caster<sycl::device>
114
149
template < > struct type_caster < sycl ::context >
115
150
{
116
151
public :
117
- PYBIND11_TYPE_CASTER (sycl ::context , _ ("dpctl.SyclContext" ));
118
-
119
152
bool load (handle src , bool )
120
153
{
121
154
PyObject * source = src .ptr ();
122
155
if (PyObject_TypeCheck (source , & PySyclContextType )) {
123
156
DPCTLSyclContextRef CRef = SyclContext_GetContextRef (
124
157
reinterpret_cast < PySyclContextObject * > (source ));
125
- sycl :: context * ctx = reinterpret_cast < sycl ::context * > ( CRef );
126
- value = * ctx ;
158
+ value = std :: make_unique < sycl ::context > (
159
+ * ( reinterpret_cast < sycl :: context * > ( CRef ))) ;
127
160
return true;
128
161
}
129
162
else {
@@ -138,6 +171,8 @@ template <> struct type_caster<sycl::context>
138
171
SyclContext_Make (reinterpret_cast < DPCTLSyclContextRef > (& src ));
139
172
return handle (reinterpret_cast < PyObject * > (tmp ));
140
173
}
174
+
175
+ DPCTL_TYPE_CASTER (sycl ::context , _ ("dpctl.SyclContext" ));
141
176
};
142
177
143
178
/* This type caster associates ``sycl::event`` C++ class with
@@ -147,16 +182,14 @@ template <> struct type_caster<sycl::context>
147
182
template < > struct type_caster < sycl ::event >
148
183
{
149
184
public :
150
- PYBIND11_TYPE_CASTER (sycl ::event , _ ("dpctl.SyclEvent" ));
151
-
152
185
bool load (handle src , bool )
153
186
{
154
187
PyObject * source = src .ptr ();
155
188
if (PyObject_TypeCheck (source , & PySyclEventType )) {
156
189
DPCTLSyclEventRef ERef = SyclEvent_GetEventRef (
157
190
reinterpret_cast < PySyclEventObject * > (source ));
158
- sycl :: event * ev = reinterpret_cast < sycl ::event * > ( ERef );
159
- value = * ev ;
191
+ value = std :: make_unique < sycl ::event > (
192
+ * ( reinterpret_cast < sycl :: event * > ( ERef ))) ;
160
193
return true;
161
194
}
162
195
else {
@@ -170,12 +203,102 @@ template <> struct type_caster<sycl::event>
170
203
auto tmp = SyclEvent_Make (reinterpret_cast < DPCTLSyclEventRef > (& src ));
171
204
return handle (reinterpret_cast < PyObject * > (tmp ));
172
205
}
206
+
207
+ DPCTL_TYPE_CASTER (sycl ::event , _ ("dpctl.SyclEvent" ));
173
208
};
174
209
} // namespace detail
175
210
} // namespace pybind11
176
211
177
212
namespace dpctl
178
213
{
214
+
215
+ namespace detail
216
+ {
217
+
218
+ struct dpctl_api
219
+ {
220
+ public :
221
+ static dpctl_api & get ()
222
+ {
223
+ static dpctl_api api ;
224
+ return api ;
225
+ }
226
+
227
+ py ::object sycl_queue_ ()
228
+ {
229
+ return * sycl_queue ;
230
+ }
231
+ py ::object default_usm_memory_ ()
232
+ {
233
+ return * default_usm_memory ;
234
+ }
235
+ py ::object default_usm_ndarray_ ()
236
+ {
237
+ return * default_usm_ndarray ;
238
+ }
239
+ py ::object as_usm_memory_ ()
240
+ {
241
+ return * as_usm_memory ;
242
+ }
243
+
244
+ private :
245
+ struct Deleter
246
+ {
247
+ void operator ()(py ::object * p ) const
248
+ {
249
+ bool guard = (Py_IsInitialized () && !_Py_IsFinalizing ());
250
+
251
+ if (guard ) {
252
+ delete p ;
253
+ }
254
+ }
255
+ };
256
+
257
+ std ::shared_ptr < py ::object > sycl_queue ;
258
+ std ::shared_ptr < py ::object > default_usm_memory ;
259
+ std ::shared_ptr < py ::object > default_usm_ndarray ;
260
+ std ::shared_ptr < py ::object > as_usm_memory ;
261
+
262
+ dpctl_api () : sycl_queue {}, default_usm_memory {}, default_usm_ndarray {}
263
+ {
264
+ import_dpctl ();
265
+
266
+ sycl ::queue q_ ;
267
+ py ::object py_sycl_queue = py ::cast (q_ );
268
+ sycl_queue = std ::shared_ptr < py ::object > (new py ::object {py_sycl_queue },
269
+ Deleter {});
270
+
271
+ py ::module_ mod_memory = py ::module_ ::import ("dpctl.memory" );
272
+ py ::object py_as_usm_memory = mod_memory .attr ("as_usm_memory" );
273
+ as_usm_memory = std ::shared_ptr < py ::object > (
274
+ new py ::object {py_as_usm_memory }, Deleter {});
275
+
276
+ auto mem_kl = mod_memory .attr ("MemoryUSMHost" );
277
+ py ::object py_default_usm_memory =
278
+ mem_kl (1 , py ::arg ("queue" ) = py_sycl_queue );
279
+ default_usm_memory = std ::shared_ptr < py ::object > (
280
+ new py ::object {py_default_usm_memory }, Deleter {});
281
+
282
+ py ::module_ mod_usmarray =
283
+ py ::module_ ::import ("dpctl.tensor._usmarray" );
284
+ auto tensor_kl = mod_usmarray .attr ("usm_ndarray" );
285
+
286
+ py ::object py_default_usm_ndarray =
287
+ tensor_kl (py ::tuple (), py ::arg ("dtype" ) = py ::str ("u1" ),
288
+ py ::arg ("buffer" ) = py_default_usm_memory );
289
+
290
+ default_usm_ndarray = std ::shared_ptr < py ::object > (
291
+ new py ::object {py_default_usm_ndarray }, Deleter {});
292
+ }
293
+
294
+ public :
295
+ dpctl_api (dpctl_api const & ) = delete ;
296
+ void operator = (dpctl_api const & ) = delete ;
297
+ ~dpctl_api (){};
298
+ };
299
+
300
+ } // namespace detail
301
+
179
302
namespace memory
180
303
{
181
304
@@ -232,7 +355,9 @@ class usm_memory : public py::object
232
355
}
233
356
// END_TOKEN
234
357
235
- usm_memory () : py ::object (default_constructed (), stolen_t {})
358
+ usm_memory ()
359
+ : py ::object (::dpctl ::detail ::dpctl_api ::get ().default_usm_memory_ (),
360
+ borrowed_t {})
236
361
{
237
362
if (!m_ptr )
238
363
throw py ::error_already_set ();
@@ -267,26 +392,12 @@ class usm_memory : public py::object
267
392
"cannot create a usm_memory from a nullptr" );
268
393
return nullptr ;
269
394
}
270
- py ::module_ m = py ::module_ ::import ("dpctl.memory" );
271
- auto convertor = m .attr ("as_usm_memory" );
272
395
273
- py ::object res ;
274
- try {
275
- res = convertor (py ::handle (o ));
276
- } catch (const py ::error_already_set & e ) {
277
- return nullptr ;
278
- }
279
- return res .ptr ();
280
- }
396
+ auto convertor = ::dpctl ::detail ::dpctl_api ::get ().as_usm_memory_ ();
281
397
282
- static PyObject * default_constructed ()
283
- {
284
- py ::module_ m = py ::module_ ::import ("dpctl.memory" );
285
- auto kl = m .attr ("MemoryUSMDevice" );
286
398
py ::object res ;
287
399
try {
288
- // allocate 1 byte
289
- res = kl (1 );
400
+ res = convertor (py ::handle (o ));
290
401
} catch (const py ::error_already_set & e ) {
291
402
return nullptr ;
292
403
}
@@ -295,10 +406,7 @@ class usm_memory : public py::object
295
406
};
296
407
297
408
} // end namespace memory
298
- } // end namespace dpctl
299
409
300
- namespace dpctl
301
- {
302
410
namespace tensor
303
411
{
304
412
class usm_ndarray : public py ::object
@@ -349,7 +457,9 @@ class usm_ndarray : public py::object
349
457
}
350
458
// END_TOKEN
351
459
352
- usm_ndarray () : py ::object (default_constructed (), stolen_t {})
460
+ usm_ndarray ()
461
+ : py ::object (::dpctl ::detail ::dpctl_api ::get ().default_usm_ndarray_ (),
462
+ borrowed_t {})
353
463
{
354
464
if (!m_ptr )
355
465
throw py ::error_already_set ();
@@ -481,21 +591,6 @@ class usm_ndarray : public py::object
481
591
482
592
return UsmNDArray_GetElementSize (raw_ar );
483
593
}
484
-
485
- private :
486
- static PyObject * default_constructed ()
487
- {
488
- py ::module_ m = py ::module_ ::import ("dpctl.tensor" );
489
- auto kl = m .attr ("usm_ndarray" );
490
- py ::object res ;
491
- try {
492
- // allocate 1 byte
493
- res = kl (py ::make_tuple (), py ::arg ("dtype" ) = "u1" );
494
- } catch (const py ::error_already_set & e ) {
495
- return nullptr ;
496
- }
497
- return res .ptr ();
498
- }
499
594
};
500
595
501
596
} // end namespace tensor
0 commit comments