Skip to content

Commit 9089a84

Browse files
Moved implementation of kernels out to dedicated header files.
1 parent 051e473 commit 9089a84

File tree

5 files changed

+1149
-896
lines changed

5 files changed

+1149
-896
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,55 @@ class usm_memory : public py::object
371371

372372
namespace tensor
373373
{
374+
375+
std::vector<py::ssize_t> c_contiguous_strides(int nd,
376+
const py::ssize_t *shape,
377+
py::ssize_t element_size = 1)
378+
{
379+
if (nd > 0) {
380+
std::vector<py::ssize_t> c_strides(nd, element_size);
381+
for (int ic = nd - 1; ic > 0;) {
382+
py::ssize_t next_v = c_strides[ic] * shape[ic];
383+
c_strides[--ic] = next_v;
384+
}
385+
return c_strides;
386+
}
387+
else {
388+
return std::vector<py::ssize_t>();
389+
}
390+
}
391+
392+
std::vector<py::ssize_t> f_contiguous_strides(int nd,
393+
const py::ssize_t *shape,
394+
py::ssize_t element_size = 1)
395+
{
396+
if (nd > 0) {
397+
std::vector<py::ssize_t> f_strides(nd, element_size);
398+
for (int i = 0; i < nd - 1;) {
399+
py::ssize_t next_v = f_strides[i] * shape[i];
400+
f_strides[++i] = next_v;
401+
}
402+
return f_strides;
403+
}
404+
else {
405+
return std::vector<py::ssize_t>();
406+
}
407+
}
408+
409+
std::vector<py::ssize_t>
410+
c_contiguous_strides(const std::vector<py::ssize_t> &shape,
411+
py::ssize_t element_size = 1)
412+
{
413+
return c_contiguous_strides(shape.size(), shape.data(), element_size);
414+
}
415+
416+
std::vector<py::ssize_t>
417+
f_contiguous_strides(const std::vector<py::ssize_t> &shape,
418+
py::ssize_t element_size = 1)
419+
{
420+
return f_contiguous_strides(shape.size(), shape.data(), element_size);
421+
}
422+
374423
class usm_ndarray : public py::object
375424
{
376425
public:

0 commit comments

Comments
 (0)