Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redesign dpnp.put_along_axis and dpnp.take_along_axis thorough existing calls #1636

Merged
merged 4 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,21 +231,19 @@ enum class DPNPFuncName : size_t
DPNP_FN_PTP, /**< Used in numpy.ptp() impl */
DPNP_FN_PUT, /**< Used in numpy.put() impl */
DPNP_FN_PUT_ALONG_AXIS, /**< Used in numpy.put_along_axis() impl */
DPNP_FN_PUT_ALONG_AXIS_EXT, /**< Used in numpy.put_along_axis() impl,
requires extra parameters */
DPNP_FN_QR, /**< Used in numpy.linalg.qr() impl */
DPNP_FN_QR_EXT, /**< Used in numpy.linalg.qr() impl, requires extra
parameters */
DPNP_FN_RADIANS, /**< Used in numpy.radians() impl */
DPNP_FN_RADIANS_EXT, /**< Used in numpy.radians() impl, requires extra
parameters */
DPNP_FN_REMAINDER, /**< Used in numpy.remainder() impl */
DPNP_FN_RECIP, /**< Used in numpy.recip() impl */
DPNP_FN_RECIP_EXT, /**< Used in numpy.recip() impl, requires extra
parameters */
DPNP_FN_REPEAT, /**< Used in numpy.repeat() impl */
DPNP_FN_RIGHT_SHIFT, /**< Used in numpy.right_shift() impl */
DPNP_FN_RNG_BETA, /**< Used in numpy.random.beta() impl */
DPNP_FN_QR, /**< Used in numpy.linalg.qr() impl */
DPNP_FN_QR_EXT, /**< Used in numpy.linalg.qr() impl, requires extra
parameters */
DPNP_FN_RADIANS, /**< Used in numpy.radians() impl */
DPNP_FN_RADIANS_EXT, /**< Used in numpy.radians() impl, requires extra
parameters */
DPNP_FN_REMAINDER, /**< Used in numpy.remainder() impl */
DPNP_FN_RECIP, /**< Used in numpy.recip() impl */
DPNP_FN_RECIP_EXT, /**< Used in numpy.recip() impl, requires extra
parameters */
DPNP_FN_REPEAT, /**< Used in numpy.repeat() impl */
DPNP_FN_RIGHT_SHIFT, /**< Used in numpy.right_shift() impl */
DPNP_FN_RNG_BETA, /**< Used in numpy.random.beta() impl */
DPNP_FN_RNG_BETA_EXT, /**< Used in numpy.random.beta() impl, requires extra
parameters */
DPNP_FN_RNG_BINOMIAL, /**< Used in numpy.random.binomial() impl */
Expand Down
22 changes: 0 additions & 22 deletions dpnp/backend/kernels/dpnp_krnl_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -796,19 +796,6 @@ void (*dpnp_put_along_axis_default_c)(void *,
size_t) =
dpnp_put_along_axis_c<_DataType>;

template <typename _DataType>
DPCTLSyclEventRef (*dpnp_put_along_axis_ext_c)(DPCTLSyclQueueRef,
void *,
long *,
void *,
size_t,
const shape_elem_type *,
size_t,
size_t,
size_t,
const DPCTLEventVectorRef) =
dpnp_put_along_axis_c<_DataType>;

template <typename _DataType, typename _IndecesType>
class dpnp_take_c_kernel;

Expand Down Expand Up @@ -1005,15 +992,6 @@ void func_map_init_indexing_func(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_put_along_axis_default_c<double>};

fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS_EXT][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_put_along_axis_ext_c<int32_t>};
fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void *)dpnp_put_along_axis_ext_c<int64_t>};
fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_put_along_axis_ext_c<float>};
fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_put_along_axis_ext_c<double>};

fmap[DPNPFuncName::DPNP_FN_TAKE][eft_BLN][eft_INT] = {
eft_BLN, (void *)dpnp_take_default_c<bool, int32_t>};
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_INT][eft_INT] = {
Expand Down
2 changes: 0 additions & 2 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_RNG_POISSON_EXT
DPNP_FN_RNG_POWER
DPNP_FN_RNG_POWER_EXT
DPNP_FN_PUT_ALONG_AXIS
DPNP_FN_PUT_ALONG_AXIS_EXT
DPNP_FN_RNG_RAYLEIGH
DPNP_FN_RNG_RAYLEIGH_EXT
DPNP_FN_RNG_SHUFFLE
Expand Down
129 changes: 0 additions & 129 deletions dpnp/dpnp_algo/dpnp_algo_indexing.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,8 @@ __all__ += [
"dpnp_diagonal",
"dpnp_fill_diagonal",
"dpnp_indices",
"dpnp_put_along_axis",
"dpnp_putmask",
"dpnp_select",
"dpnp_take_along_axis",
"dpnp_tril_indices",
"dpnp_tril_indices_from",
"dpnp_triu_indices",
Expand All @@ -69,16 +67,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_2in_1out_func_ptr_t_)(c_dpct
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_2in_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef,
void *, void * , shape_elem_type * , const size_t,
const c_dpctl.DPCTLEventVectorRef)
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_3in_with_axis_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef,
void * ,
void * ,
void * ,
const size_t,
shape_elem_type * ,
const size_t,
const size_t,
const size_t,
const c_dpctl.DPCTLEventVectorRef)


cpdef utils.dpnp_descriptor dpnp_choose(utils.dpnp_descriptor x1, list choices1):
Expand Down Expand Up @@ -283,35 +271,6 @@ cpdef object dpnp_indices(dimensions):
return dpnp_result


cpdef dpnp_put_along_axis(dpnp_descriptor arr, dpnp_descriptor indices, dpnp_descriptor values, int axis):
cdef shape_type_c arr_shape = arr.shape
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype)

cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_PUT_ALONG_AXIS_EXT, param1_type, param1_type)

utils.get_common_usm_allocation(arr, indices) # check USM allocation is common
_, _, result_sycl_queue = utils.get_common_usm_allocation(arr, values)

cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()

cdef custom_indexing_3in_with_axis_func_ptr_t func = <custom_indexing_3in_with_axis_func_ptr_t > kernel_data.ptr

cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
arr.get_data(),
indices.get_data(),
values.get_data(),
axis,
arr_shape.data(),
arr.ndim,
indices.size,
values.size,
NULL) # dep_events_ref

with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
c_dpctl.DPCTLEvent_Delete(event_ref)


cpdef dpnp_putmask(utils.dpnp_descriptor arr, utils.dpnp_descriptor mask, utils.dpnp_descriptor values):
cdef int values_size = values.size

Expand Down Expand Up @@ -341,94 +300,6 @@ cpdef utils.dpnp_descriptor dpnp_select(list condlist, list choicelist, default)
return res_array


cpdef object dpnp_take_along_axis(object arr, object indices, int axis):
cdef long size_arr = arr.size
cdef shape_type_c shape_arr = arr.shape
cdef shape_type_c output_shape
cdef long size_indices = indices.size
res_type = arr.dtype

if axis != arr.ndim - 1:
res_shape_list = list(shape_arr)
res_shape_list[axis] = 1
res_shape = tuple(res_shape_list)

output_shape = (0,) * (len(shape_arr) - 1)
ind = 0
for id, shape_axis in enumerate(shape_arr):
if id != axis:
output_shape[ind] = shape_axis
ind += 1

prod = 1
for i in range(len(output_shape)):
if output_shape[i] != 0:
prod *= output_shape[i]

result_array = dpnp.empty((prod, ), dtype=res_type)
ind_array = [None] * prod
arr_shape_offsets = [None] * len(shape_arr)
acc = 1

for i in range(len(shape_arr)):
ind = len(shape_arr) - 1 - i
arr_shape_offsets[ind] = acc
acc *= shape_arr[ind]

output_shape_offsets = [None] * len(shape_arr)
acc = 1

for i in range(len(output_shape)):
ind = len(output_shape) - 1 - i
output_shape_offsets[ind] = acc
acc *= output_shape[ind]
result_offsets = arr_shape_offsets[:] # need copy. not a reference
result_offsets[axis] = 0

for source_idx in range(size_arr):

# reconstruct x,y,z from linear source_idx
xyz = []
remainder = source_idx
for i in arr_shape_offsets:
quotient, remainder = divmod(remainder, i)
xyz.append(quotient)

# extract result axis
result_axis = []
for idx, offset in enumerate(xyz):
if idx != axis:
result_axis.append(offset)

# Construct result offset
result_offset = 0
for i, result_axis_val in enumerate(result_axis):
result_offset += (output_shape_offsets[i] * result_axis_val)

arr_elem = arr.item(source_idx)
if ind_array[result_offset] is None:
ind_array[result_offset] = 0
else:
ind_array[result_offset] += 1

if ind_array[result_offset] % size_indices == indices.item(result_offset % size_indices):
result_array[result_offset] = arr_elem

dpnp_result_array = dpnp.reshape(result_array, res_shape)
return dpnp_result_array

else:
result_array = utils_py.create_output_descriptor_py(shape_arr, res_type, None).get_pyobj()

result_array_flatiter = result_array.flat

for i in range(size_arr):
ind = size_indices * (i // size_indices) + indices.item(i % size_indices)
result_array_flatiter[i] = arr.item(ind)

return result_array


cpdef tuple dpnp_tril_indices(n, k=0, m=None):
array1 = []
array2 = []
Expand Down
40 changes: 40 additions & 0 deletions dpnp/dpnp_iface.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"array_equal",
"asnumpy",
"astype",
"check_supported_arrays_type",
"convert_single_elem_array_to_scalar",
"default_float_type",
"dpnp_queue_initialize",
Expand Down Expand Up @@ -203,6 +204,45 @@ def astype(x1, dtype, order="K", casting="unsafe", copy=True):
return dpnp_array._create_from_usm_ndarray(array_obj)


def check_supported_arrays_type(*arrays, scalar_type=False):
"""
Return ``True`` if each array has either type of scalar,
:class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
But if any array has unsupported type, ``TypeError`` will be raised.

Parameters
----------
arrays : {dpnp_array, usm_ndarray}
Input arrays to check for supported types.
scalar_type : {bool}, optional
A scalar type is also considered as supported if flag is True.

Returns
-------
out : bool
``True`` if each type of input `arrays` is supported type,
``False`` otherwise.

Raises
------
TypeError
If any input array from `arrays` is of unsupported array type.

"""

for a in arrays:
if not (
(scalar_type or is_supported_array_type(a))
and is_supported_array_or_scalar(a)
):
antonwolfy marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError(
"An array must be any of supported type, but got {}".format(
type(a)
)
)
return True


def convert_single_elem_array_to_scalar(obj, keepdims=False):
"""Convert array with single element to scalar."""

Expand Down
Loading