Skip to content

Commit 58d2215

Browse files
Moved uses of py::object from kernel headers to implementation c++ files
This change affected linspace, arange, and full constructors. This change causes no user visible changes.
1 parent 11a933e commit 58d2215

File tree

4 files changed

+248
-208
lines changed

4 files changed

+248
-208
lines changed

dpctl/tensor/libtensor/include/kernels/constructors.hpp

Lines changed: 0 additions & 200 deletions
Original file line numberDiff line numberDiff line change
@@ -51,34 +51,6 @@ template <typename Ty> class eye_kernel;
5151
namespace py = pybind11;
5252
using namespace dpctl::tensor::offset_utils;
5353

54-
/* =========== Unboxing Python scalar =============== */
55-
56-
/*!
57-
* @brief Cast pybind11 class managing Python object to specified type `T`.
58-
* @defgroup CtorKernels
59-
*/
60-
template <typename T> T unbox_py_scalar(const py::object &o)
61-
{
62-
return py::cast<T>(o);
63-
}
64-
65-
template <> inline sycl::half unbox_py_scalar<sycl::half>(const py::object &o)
66-
{
67-
float tmp = py::cast<float>(o);
68-
return static_cast<sycl::half>(tmp);
69-
}
70-
71-
// Constructor to populate tensor with linear sequence defined by
72-
// start and step data
73-
74-
typedef sycl::event (*lin_space_step_fn_ptr_t)(
75-
sycl::queue &,
76-
size_t, // num_elements
77-
const py::object &start,
78-
const py::object &step,
79-
char *, // dst_data_ptr
80-
const std::vector<sycl::event> &);
81-
8254
template <typename Ty> class LinearSequenceStepFunctor
8355
{
8456
private:
@@ -142,74 +114,9 @@ sycl::event lin_space_step_impl(sycl::queue &exec_q,
142114
return lin_space_step_event;
143115
}
144116

145-
/*!
146-
* @brief Function to submit kernel to populate given contiguous memory
147-
* allocation with linear sequence specified by starting value and increment
148-
* given as Python objects.
149-
*
150-
* @param q Sycl queue to which the kernel is submitted
151-
* @param nelems Length of the sequence
152-
* @param start Starting value of the sequence as Python object. Must be
153-
* convertible to array element data type `Ty`.
154-
* @param step Increment of the sequence as Python object. Must be convertible
155-
* to array element data type `Ty`.
156-
* @param array_data Kernel accessible USM pointer to the start of array to be
157-
* populated.
158-
* @param depends List of events to wait for before starting computations, if
159-
* any.
160-
*
161-
* @return Event to wait on to ensure that computation completes.
162-
* @defgroup CtorKernels
163-
*/
164-
template <typename Ty>
165-
sycl::event lin_space_step_impl(sycl::queue &exec_q,
166-
size_t nelems,
167-
const py::object &start,
168-
const py::object &step,
169-
char *array_data,
170-
const std::vector<sycl::event> &depends)
171-
{
172-
Ty start_v;
173-
Ty step_v;
174-
try {
175-
start_v = unbox_py_scalar<Ty>(start);
176-
step_v = unbox_py_scalar<Ty>(step);
177-
} catch (const py::error_already_set &e) {
178-
throw;
179-
}
180-
181-
auto lin_space_step_event = lin_space_step_impl<Ty>(
182-
exec_q, nelems, start_v, step_v, array_data, depends);
183-
184-
return lin_space_step_event;
185-
}
186-
187-
/*!
188-
* @brief Factor to get function pointer of type `fnT` for array with elements
189-
* of type `Ty`.
190-
* @defgroup CtorKernels
191-
*/
192-
template <typename fnT, typename Ty> struct LinSpaceStepFactory
193-
{
194-
fnT get()
195-
{
196-
fnT f = lin_space_step_impl<Ty>;
197-
return f;
198-
}
199-
};
200-
201117
// Constructor to populate tensor with linear sequence defined by
202118
// start and and data
203119

204-
typedef sycl::event (*lin_space_affine_fn_ptr_t)(
205-
sycl::queue &,
206-
size_t, // num_elements
207-
const py::object &start,
208-
const py::object &end,
209-
bool include_endpoint,
210-
char *, // dst_data_ptr
211-
const std::vector<sycl::event> &);
212-
213120
template <typename Ty, typename wTy> class LinearSequenceAffineFunctor
214121
{
215122
private:
@@ -312,70 +219,8 @@ sycl::event lin_space_affine_impl(sycl::queue &exec_q,
312219
return lin_space_affine_event;
313220
}
314221

315-
/*!
316-
* @brief Function to submit kernel to populate given contiguous memory
317-
* allocation with linear sequence specified by starting and end values given
318-
* as Python objects.
319-
*
320-
* @param exec_q Sycl queue to which kernel is submitted for execution.
321-
* @param nelems Length of the sequence
322-
* @param start Stating value of the sequence as Python object. Must be
323-
* convertible to array data element type `Ty`.
324-
* @param end End-value of the sequence as Python object. Must be convertible
325-
* to array data element type `Ty`.
326-
* @param include_endpoint Whether the end-value is included in the sequence
327-
* @param array_data Kernel accessible USM pointer to the start of array to be
328-
* populated.
329-
* @param depends List of events to wait for before starting computations, if
330-
* any.
331-
*
332-
* @return Event to wait on to ensure that computation completes.
333-
* @defgroup CtorKernels
334-
*/
335-
template <typename Ty>
336-
sycl::event lin_space_affine_impl(sycl::queue &exec_q,
337-
size_t nelems,
338-
const py::object &start,
339-
const py::object &end,
340-
bool include_endpoint,
341-
char *array_data,
342-
const std::vector<sycl::event> &depends)
343-
{
344-
Ty start_v, end_v;
345-
try {
346-
start_v = unbox_py_scalar<Ty>(start);
347-
end_v = unbox_py_scalar<Ty>(end);
348-
} catch (const py::error_already_set &e) {
349-
throw;
350-
}
351-
352-
auto lin_space_affine_event = lin_space_affine_impl<Ty>(
353-
exec_q, nelems, start_v, end_v, include_endpoint, array_data, depends);
354-
355-
return lin_space_affine_event;
356-
}
357-
358-
/*!
359-
* @brief Factory to get function pointer of type `fnT` for array data type
360-
* `Ty`.
361-
*/
362-
template <typename fnT, typename Ty> struct LinSpaceAffineFactory
363-
{
364-
fnT get()
365-
{
366-
fnT f = lin_space_affine_impl<Ty>;
367-
return f;
368-
}
369-
};
370-
371222
/* ================ Full ================== */
372223

373-
typedef sycl::event (*full_contig_fn_ptr_t)(sycl::queue &,
374-
size_t,
375-
const py::object &,
376-
char *,
377-
const std::vector<sycl::event> &);
378-
379224
/*!
380225
* @brief Function to submit kernel to fill given contiguous memory allocation
381226
* with specified value.
@@ -408,51 +253,6 @@ sycl::event full_contig_impl(sycl::queue &q,
408253
return fill_ev;
409254
}
410255

411-
/*!
412-
* @brief Function to submit kernel to fill given contiguous memory allocation
413-
* with specified value.
414-
*
415-
* @param exec_q Sycl queue to which kernel is submitted for execution.
416-
* @param nelems Length of the sequence
417-
* @param py_value Python object representing the value to fill the array with.
418-
* Must be convertible to `dstTy`.
419-
* @param dst_p Kernel accessible USM pointer to the start of array to be
420-
* populated.
421-
* @param depends List of events to wait for before starting computations, if
422-
* any.
423-
*
424-
* @return Event to wait on to ensure that computation completes.
425-
* @defgroup CtorKernels
426-
*/
427-
template <typename dstTy>
428-
sycl::event full_contig_impl(sycl::queue &exec_q,
429-
size_t nelems,
430-
const py::object &py_value,
431-
char *dst_p,
432-
const std::vector<sycl::event> &depends)
433-
{
434-
dstTy fill_v;
435-
try {
436-
fill_v = unbox_py_scalar<dstTy>(py_value);
437-
} catch (const py::error_already_set &e) {
438-
throw;
439-
}
440-
441-
sycl::event fill_ev =
442-
full_contig_impl<dstTy>(exec_q, nelems, fill_v, dst_p, depends);
443-
444-
return fill_ev;
445-
}
446-
447-
template <typename fnT, typename Ty> struct FullContigFactory
448-
{
449-
fnT get()
450-
{
451-
fnT f = full_contig_impl<Ty>;
452-
return f;
453-
}
454-
};
455-
456256
/* ================ Eye ================== */
457257

458258
typedef sycl::event (*eye_fn_ptr_t)(sycl::queue &,

dpctl/tensor/libtensor/source/full_ctor.cpp

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "utils/type_utils.hpp"
3636

3737
#include "full_ctor.hpp"
38+
#include "unboxing_helper.hpp"
3839

3940
namespace py = pybind11;
4041
namespace td_ns = dpctl::tensor::type_dispatch;
@@ -48,7 +49,60 @@ namespace py_internal
4849

4950
using dpctl::utils::keep_args_alive;
5051

51-
using dpctl::tensor::kernels::constructors::full_contig_fn_ptr_t;
52+
typedef sycl::event (*full_contig_fn_ptr_t)(sycl::queue &,
53+
size_t,
54+
const py::object &,
55+
char *,
56+
const std::vector<sycl::event> &);
57+
58+
/*!
59+
* @brief Function to submit kernel to fill given contiguous memory allocation
60+
* with specified value.
61+
*
62+
* @param exec_q Sycl queue to which kernel is submitted for execution.
63+
* @param nelems Length of the sequence
64+
* @param py_value Python object representing the value to fill the array with.
65+
* Must be convertible to `dstTy`.
66+
* @param dst_p Kernel accessible USM pointer to the start of array to be
67+
* populated.
68+
* @param depends List of events to wait for before starting computations, if
69+
* any.
70+
*
71+
* @return Event to wait on to ensure that computation completes.
72+
* @defgroup CtorKernels
73+
*/
74+
template <typename dstTy>
75+
sycl::event full_contig_impl(sycl::queue &exec_q,
76+
size_t nelems,
77+
const py::object &py_value,
78+
char *dst_p,
79+
const std::vector<sycl::event> &depends)
80+
{
81+
dstTy fill_v;
82+
83+
PythonObjectUnboxer<dstTy> unboxer{};
84+
try {
85+
fill_v = unboxer(py_value);
86+
} catch (const py::error_already_set &e) {
87+
throw;
88+
}
89+
90+
using dpctl::tensor::kernels::constructors::full_contig_impl;
91+
92+
sycl::event fill_ev =
93+
full_contig_impl<dstTy>(exec_q, nelems, fill_v, dst_p, depends);
94+
95+
return fill_ev;
96+
}
97+
98+
template <typename fnT, typename Ty> struct FullContigFactory
99+
{
100+
fnT get()
101+
{
102+
fnT f = full_contig_impl<Ty>;
103+
return f;
104+
}
105+
};
52106

53107
static full_contig_fn_ptr_t full_contig_dispatch_vector[td_ns::num_types];
54108

@@ -99,7 +153,6 @@ usm_ndarray_full(const py::object &py_value,
99153
void init_full_ctor_dispatch_vectors(void)
100154
{
101155
using namespace td_ns;
102-
using dpctl::tensor::kernels::constructors::FullContigFactory;
103156

104157
DispatchVectorBuilder<full_contig_fn_ptr_t, FullContigFactory, num_types>
105158
dvb;

0 commit comments

Comments
 (0)