Skip to content

Added get_strides_vector and get_shape_vector #1090

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 1, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 50 additions & 13 deletions dpctl/apis/include/dpctl4pybind11.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ class usm_ndarray : public py::object

char *get_data() const
{
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
PyUSMArrayObject *raw_ar = usm_array_ptr();

auto const &api = ::dpctl::detail::dpctl_capi::get();
return api.UsmNDArray_GetData_(raw_ar);
Expand All @@ -842,20 +842,29 @@ class usm_ndarray : public py::object

int get_ndim() const
{
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
PyUSMArrayObject *raw_ar = usm_array_ptr();

auto const &api = ::dpctl::detail::dpctl_capi::get();
return api.UsmNDArray_GetNDim_(raw_ar);
}

const py::ssize_t *get_shape_raw() const
{
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
PyUSMArrayObject *raw_ar = usm_array_ptr();

auto const &api = ::dpctl::detail::dpctl_capi::get();
return api.UsmNDArray_GetShape_(raw_ar);
}

std::vector<py::ssize_t> get_shape_vector() const
{
auto raw_sh = get_shape_raw();
auto nd = get_ndim();

std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
return shape_vector;
}

py::ssize_t get_shape(int i) const
{
auto shape_ptr = get_shape_raw();
Expand All @@ -864,15 +873,43 @@ class usm_ndarray : public py::object

const py::ssize_t *get_strides_raw() const
{
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
PyUSMArrayObject *raw_ar = usm_array_ptr();

auto const &api = ::dpctl::detail::dpctl_capi::get();
return api.UsmNDArray_GetStrides_(raw_ar);
}

std::vector<py::ssize_t> get_strides_vector() const
{
auto raw_st = get_strides_raw();
auto nd = get_ndim();

if (raw_st == nullptr) {
auto is_c_contig = is_c_contiguous();
auto is_f_contig = is_f_contiguous();
auto raw_sh = get_shape_raw();
if (is_c_contig) {
const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
return contig_strides;
}
else if (is_f_contig) {
const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
return contig_strides;
}
else {
throw std::runtime_error("Invalid array encountered when "
"building strides");
}
}
else {
std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
return st_vec;
}
}

py::ssize_t get_size() const
{
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
PyUSMArrayObject *raw_ar = usm_array_ptr();

auto const &api = ::dpctl::detail::dpctl_capi::get();
int ndim = api.UsmNDArray_GetNDim_(raw_ar);
Expand All @@ -889,7 +926,7 @@ class usm_ndarray : public py::object

std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets() const
{
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
PyUSMArrayObject *raw_ar = usm_array_ptr();

auto const &api = ::dpctl::detail::dpctl_capi::get();
int nd = api.UsmNDArray_GetNDim_(raw_ar);
Expand Down Expand Up @@ -923,7 +960,7 @@ class usm_ndarray : public py::object

sycl::queue get_queue() const
{
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
PyUSMArrayObject *raw_ar = usm_array_ptr();

auto const &api = ::dpctl::detail::dpctl_capi::get();
DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
Expand All @@ -932,45 +969,45 @@ class usm_ndarray : public py::object

int get_typenum() const
{
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
PyUSMArrayObject *raw_ar = usm_array_ptr();

auto const &api = ::dpctl::detail::dpctl_capi::get();
return api.UsmNDArray_GetTypenum_(raw_ar);
}

int get_flags() const
{
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
PyUSMArrayObject *raw_ar = usm_array_ptr();

auto const &api = ::dpctl::detail::dpctl_capi::get();
return api.UsmNDArray_GetFlags_(raw_ar);
}

int get_elemsize() const
{
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
PyUSMArrayObject *raw_ar = usm_array_ptr();

auto const &api = ::dpctl::detail::dpctl_capi::get();
return api.UsmNDArray_GetElementSize_(raw_ar);
}

bool is_c_contiguous() const
{
int flags = this->get_flags();
int flags = get_flags();
auto const &api = ::dpctl::detail::dpctl_capi::get();
return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
}

bool is_f_contiguous() const
{
int flags = this->get_flags();
int flags = get_flags();
auto const &api = ::dpctl::detail::dpctl_capi::get();
return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
}

bool is_writable() const
{
int flags = this->get_flags();
int flags = get_flags();
auto const &api = ::dpctl::detail::dpctl_capi::get();
return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
}
Expand Down