Skip to content

Commit 7592c8b

Browse files
Improve performance of default constructors for usm_memory and ums_ndarray
1 parent 759a1fa commit 7592c8b

File tree

1 file changed

+151
-56
lines changed

1 file changed

+151
-56
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 151 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,60 @@ namespace pybind11
3939
namespace detail
4040
{
4141

42+
#define DPCTL_TYPE_CASTER(type, py_name) \
43+
protected: \
44+
std::unique_ptr<type> value; \
45+
\
46+
public: \
47+
static constexpr auto name = py_name; \
48+
template < \
49+
typename T_, \
50+
::pybind11::detail::enable_if_t< \
51+
std::is_same<type, ::pybind11::detail::remove_cv_t<T_>>::value, \
52+
int> = 0> \
53+
static ::pybind11::handle cast(T_ *src, \
54+
::pybind11::return_value_policy policy, \
55+
::pybind11::handle parent) \
56+
{ \
57+
if (!src) \
58+
return ::pybind11::none().release(); \
59+
if (policy == ::pybind11::return_value_policy::take_ownership) { \
60+
auto h = cast(std::move(*src), policy, parent); \
61+
delete src; \
62+
return h; \
63+
} \
64+
return cast(*src, policy, parent); \
65+
} \
66+
operator type *() \
67+
{ \
68+
return value.get(); \
69+
} /* NOLINT(bugprone-macro-parentheses) */ \
70+
operator type &() \
71+
{ \
72+
return *value; \
73+
} /* NOLINT(bugprone-macro-parentheses) */ \
74+
operator type &&() && \
75+
{ \
76+
return std::move(*value); \
77+
} /* NOLINT(bugprone-macro-parentheses) */ \
78+
template <typename T_> \
79+
using cast_op_type = ::pybind11::detail::movable_cast_op_type<T_>
80+
4281
/* This type caster associates ``sycl::queue`` C++ class with
4382
* :class:`dpctl.SyclQueue` for the purposes of generation of
4483
* Python bindings by pybind11.
4584
*/
4685
template <> struct type_caster<sycl::queue>
4786
{
4887
public:
49-
PYBIND11_TYPE_CASTER(sycl::queue, _("dpctl.SyclQueue"));
50-
5188
bool load(handle src, bool)
5289
{
5390
PyObject *source = src.ptr();
5491
if (PyObject_TypeCheck(source, &PySyclQueueType)) {
5592
DPCTLSyclQueueRef QRef = SyclQueue_GetQueueRef(
5693
reinterpret_cast<PySyclQueueObject *>(source));
57-
sycl::queue *q = reinterpret_cast<sycl::queue *>(QRef);
58-
value = *q;
94+
value = std::make_unique<sycl::queue>(
95+
*(reinterpret_cast<sycl::queue *>(QRef)));
5996
return true;
6097
}
6198
else {
@@ -69,6 +106,8 @@ template <> struct type_caster<sycl::queue>
69106
auto tmp = SyclQueue_Make(reinterpret_cast<DPCTLSyclQueueRef>(&src));
70107
return handle(reinterpret_cast<PyObject *>(tmp));
71108
}
109+
110+
DPCTL_TYPE_CASTER(sycl::queue, _("dpctl.SyclQueue"));
72111
};
73112

74113
/* This type caster associates ``sycl::device`` C++ class with
@@ -78,20 +117,14 @@ template <> struct type_caster<sycl::queue>
78117
template <> struct type_caster<sycl::device>
79118
{
80119
public:
81-
PYBIND11_TYPE_CASTER(sycl::device, _("dpctl.SyclDevice"));
82-
83120
bool load(handle src, bool)
84121
{
85122
PyObject *source = src.ptr();
86123
if (PyObject_TypeCheck(source, &PySyclDeviceType)) {
87124
DPCTLSyclDeviceRef DRef = SyclDevice_GetDeviceRef(
88125
reinterpret_cast<PySyclDeviceObject *>(source));
89-
sycl::device *d = reinterpret_cast<sycl::device *>(DRef);
90-
value = *d;
91-
return true;
92-
}
93-
else if (source == Py_None) {
94-
value = sycl::device{};
126+
value = std::make_unique<sycl::device>(
127+
*(reinterpret_cast<sycl::device *>(DRef)));
95128
return true;
96129
}
97130
else {
@@ -105,6 +138,8 @@ template <> struct type_caster<sycl::device>
105138
auto tmp = SyclDevice_Make(reinterpret_cast<DPCTLSyclDeviceRef>(&src));
106139
return handle(reinterpret_cast<PyObject *>(tmp));
107140
}
141+
142+
DPCTL_TYPE_CASTER(sycl::device, _("dpctl.SyclDevice"));
108143
};
109144

110145
/* This type caster associates ``sycl::context`` C++ class with
@@ -114,16 +149,14 @@ template <> struct type_caster<sycl::device>
114149
template <> struct type_caster<sycl::context>
115150
{
116151
public:
117-
PYBIND11_TYPE_CASTER(sycl::context, _("dpctl.SyclContext"));
118-
119152
bool load(handle src, bool)
120153
{
121154
PyObject *source = src.ptr();
122155
if (PyObject_TypeCheck(source, &PySyclContextType)) {
123156
DPCTLSyclContextRef CRef = SyclContext_GetContextRef(
124157
reinterpret_cast<PySyclContextObject *>(source));
125-
sycl::context *ctx = reinterpret_cast<sycl::context *>(CRef);
126-
value = *ctx;
158+
value = std::make_unique<sycl::context>(
159+
*(reinterpret_cast<sycl::context *>(CRef)));
127160
return true;
128161
}
129162
else {
@@ -138,6 +171,8 @@ template <> struct type_caster<sycl::context>
138171
SyclContext_Make(reinterpret_cast<DPCTLSyclContextRef>(&src));
139172
return handle(reinterpret_cast<PyObject *>(tmp));
140173
}
174+
175+
DPCTL_TYPE_CASTER(sycl::context, _("dpctl.SyclContext"));
141176
};
142177

143178
/* This type caster associates ``sycl::event`` C++ class with
@@ -147,16 +182,14 @@ template <> struct type_caster<sycl::context>
147182
template <> struct type_caster<sycl::event>
148183
{
149184
public:
150-
PYBIND11_TYPE_CASTER(sycl::event, _("dpctl.SyclEvent"));
151-
152185
bool load(handle src, bool)
153186
{
154187
PyObject *source = src.ptr();
155188
if (PyObject_TypeCheck(source, &PySyclEventType)) {
156189
DPCTLSyclEventRef ERef = SyclEvent_GetEventRef(
157190
reinterpret_cast<PySyclEventObject *>(source));
158-
sycl::event *ev = reinterpret_cast<sycl::event *>(ERef);
159-
value = *ev;
191+
value = std::make_unique<sycl::event>(
192+
*(reinterpret_cast<sycl::event *>(ERef)));
160193
return true;
161194
}
162195
else {
@@ -170,12 +203,102 @@ template <> struct type_caster<sycl::event>
170203
auto tmp = SyclEvent_Make(reinterpret_cast<DPCTLSyclEventRef>(&src));
171204
return handle(reinterpret_cast<PyObject *>(tmp));
172205
}
206+
207+
DPCTL_TYPE_CASTER(sycl::event, _("dpctl.SyclEvent"));
173208
};
174209
} // namespace detail
175210
} // namespace pybind11
176211

177212
namespace dpctl
178213
{
214+
215+
namespace detail
216+
{
217+
218+
struct dpctl_api
219+
{
220+
public:
221+
static dpctl_api &get()
222+
{
223+
static dpctl_api api;
224+
return api;
225+
}
226+
227+
py::object sycl_queue_()
228+
{
229+
return *sycl_queue;
230+
}
231+
py::object default_usm_memory_()
232+
{
233+
return *default_usm_memory;
234+
}
235+
py::object default_usm_ndarray_()
236+
{
237+
return *default_usm_ndarray;
238+
}
239+
py::object as_usm_memory_()
240+
{
241+
return *as_usm_memory;
242+
}
243+
244+
private:
245+
struct Deleter
246+
{
247+
void operator()(py::object *p) const
248+
{
249+
bool guard = (Py_IsInitialized() && !_Py_IsFinalizing());
250+
251+
if (guard) {
252+
delete p;
253+
}
254+
}
255+
};
256+
257+
std::shared_ptr<py::object> sycl_queue;
258+
std::shared_ptr<py::object> default_usm_memory;
259+
std::shared_ptr<py::object> default_usm_ndarray;
260+
std::shared_ptr<py::object> as_usm_memory;
261+
262+
dpctl_api() : sycl_queue{}, default_usm_memory{}, default_usm_ndarray{}
263+
{
264+
import_dpctl();
265+
266+
sycl::queue q_;
267+
py::object py_sycl_queue = py::cast(q_);
268+
sycl_queue = std::shared_ptr<py::object>(new py::object{py_sycl_queue},
269+
Deleter{});
270+
271+
py::module_ mod_memory = py::module_::import("dpctl.memory");
272+
py::object py_as_usm_memory = mod_memory.attr("as_usm_memory");
273+
as_usm_memory = std::shared_ptr<py::object>(
274+
new py::object{py_as_usm_memory}, Deleter{});
275+
276+
auto mem_kl = mod_memory.attr("MemoryUSMHost");
277+
py::object py_default_usm_memory =
278+
mem_kl(1, py::arg("queue") = py_sycl_queue);
279+
default_usm_memory = std::shared_ptr<py::object>(
280+
new py::object{py_default_usm_memory}, Deleter{});
281+
282+
py::module_ mod_usmarray =
283+
py::module_::import("dpctl.tensor._usmarray");
284+
auto tensor_kl = mod_usmarray.attr("usm_ndarray");
285+
286+
py::object py_default_usm_ndarray =
287+
tensor_kl(py::tuple(), py::arg("dtype") = py::str("u1"),
288+
py::arg("buffer") = py_default_usm_memory);
289+
290+
default_usm_ndarray = std::shared_ptr<py::object>(
291+
new py::object{py_default_usm_ndarray}, Deleter{});
292+
}
293+
294+
public:
295+
dpctl_api(dpctl_api const &) = delete;
296+
void operator=(dpctl_api const &) = delete;
297+
~dpctl_api(){};
298+
};
299+
300+
} // namespace detail
301+
179302
namespace memory
180303
{
181304

@@ -232,7 +355,9 @@ class usm_memory : public py::object
232355
}
233356
// END_TOKEN
234357

235-
usm_memory() : py::object(default_constructed(), stolen_t{})
358+
usm_memory()
359+
: py::object(::dpctl::detail::dpctl_api::get().default_usm_memory_(),
360+
borrowed_t{})
236361
{
237362
if (!m_ptr)
238363
throw py::error_already_set();
@@ -267,26 +392,12 @@ class usm_memory : public py::object
267392
"cannot create a usm_memory from a nullptr");
268393
return nullptr;
269394
}
270-
py::module_ m = py::module_::import("dpctl.memory");
271-
auto convertor = m.attr("as_usm_memory");
272395

273-
py::object res;
274-
try {
275-
res = convertor(py::handle(o));
276-
} catch (const py::error_already_set &e) {
277-
return nullptr;
278-
}
279-
return res.ptr();
280-
}
396+
auto convertor = ::dpctl::detail::dpctl_api::get().as_usm_memory_();
281397

282-
static PyObject *default_constructed()
283-
{
284-
py::module_ m = py::module_::import("dpctl.memory");
285-
auto kl = m.attr("MemoryUSMDevice");
286398
py::object res;
287399
try {
288-
// allocate 1 byte
289-
res = kl(1);
400+
res = convertor(py::handle(o));
290401
} catch (const py::error_already_set &e) {
291402
return nullptr;
292403
}
@@ -295,10 +406,7 @@ class usm_memory : public py::object
295406
};
296407

297408
} // end namespace memory
298-
} // end namespace dpctl
299409

300-
namespace dpctl
301-
{
302410
namespace tensor
303411
{
304412
class usm_ndarray : public py::object
@@ -349,7 +457,9 @@ class usm_ndarray : public py::object
349457
}
350458
// END_TOKEN
351459

352-
usm_ndarray() : py::object(default_constructed(), stolen_t{})
460+
usm_ndarray()
461+
: py::object(::dpctl::detail::dpctl_api::get().default_usm_ndarray_(),
462+
borrowed_t{})
353463
{
354464
if (!m_ptr)
355465
throw py::error_already_set();
@@ -481,21 +591,6 @@ class usm_ndarray : public py::object
481591

482592
return UsmNDArray_GetElementSize(raw_ar);
483593
}
484-
485-
private:
486-
static PyObject *default_constructed()
487-
{
488-
py::module_ m = py::module_::import("dpctl.tensor");
489-
auto kl = m.attr("usm_ndarray");
490-
py::object res;
491-
try {
492-
// allocate 1 byte
493-
res = kl(py::make_tuple(), py::arg("dtype") = "u1");
494-
} catch (const py::error_already_set &e) {
495-
return nullptr;
496-
}
497-
return res.ptr();
498-
}
499594
};
500595

501596
} // end namespace tensor

0 commit comments

Comments
 (0)