Skip to content

Commit

Permalink
Redesign dpnp.put_along_axis and dpnp.take_along_axis thorough ex…
Browse files Browse the repository at this point in the history
…isting calls (#1636)

* Redesigned `put_along_axis` and `take_along_axis` thorugh existing calls

* Simplified check for

* Move check of array type in dpnp.prod after the TODO comment
  • Loading branch information
antonwolfy authored Dec 6, 2023
1 parent 2c8cbb5 commit a7add8e
Show file tree
Hide file tree
Showing 14 changed files with 424 additions and 321 deletions.
28 changes: 13 additions & 15 deletions dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,21 +231,19 @@ enum class DPNPFuncName : size_t
DPNP_FN_PTP, /**< Used in numpy.ptp() impl */
DPNP_FN_PUT, /**< Used in numpy.put() impl */
DPNP_FN_PUT_ALONG_AXIS, /**< Used in numpy.put_along_axis() impl */
DPNP_FN_PUT_ALONG_AXIS_EXT, /**< Used in numpy.put_along_axis() impl,
requires extra parameters */
DPNP_FN_QR, /**< Used in numpy.linalg.qr() impl */
DPNP_FN_QR_EXT, /**< Used in numpy.linalg.qr() impl, requires extra
parameters */
DPNP_FN_RADIANS, /**< Used in numpy.radians() impl */
DPNP_FN_RADIANS_EXT, /**< Used in numpy.radians() impl, requires extra
parameters */
DPNP_FN_REMAINDER, /**< Used in numpy.remainder() impl */
DPNP_FN_RECIP, /**< Used in numpy.recip() impl */
DPNP_FN_RECIP_EXT, /**< Used in numpy.recip() impl, requires extra
parameters */
DPNP_FN_REPEAT, /**< Used in numpy.repeat() impl */
DPNP_FN_RIGHT_SHIFT, /**< Used in numpy.right_shift() impl */
DPNP_FN_RNG_BETA, /**< Used in numpy.random.beta() impl */
DPNP_FN_QR, /**< Used in numpy.linalg.qr() impl */
DPNP_FN_QR_EXT, /**< Used in numpy.linalg.qr() impl, requires extra
parameters */
DPNP_FN_RADIANS, /**< Used in numpy.radians() impl */
DPNP_FN_RADIANS_EXT, /**< Used in numpy.radians() impl, requires extra
parameters */
DPNP_FN_REMAINDER, /**< Used in numpy.remainder() impl */
DPNP_FN_RECIP, /**< Used in numpy.recip() impl */
DPNP_FN_RECIP_EXT, /**< Used in numpy.recip() impl, requires extra
parameters */
DPNP_FN_REPEAT, /**< Used in numpy.repeat() impl */
DPNP_FN_RIGHT_SHIFT, /**< Used in numpy.right_shift() impl */
DPNP_FN_RNG_BETA, /**< Used in numpy.random.beta() impl */
DPNP_FN_RNG_BETA_EXT, /**< Used in numpy.random.beta() impl, requires extra
parameters */
DPNP_FN_RNG_BINOMIAL, /**< Used in numpy.random.binomial() impl */
Expand Down
22 changes: 0 additions & 22 deletions dpnp/backend/kernels/dpnp_krnl_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -796,19 +796,6 @@ void (*dpnp_put_along_axis_default_c)(void *,
size_t) =
dpnp_put_along_axis_c<_DataType>;

template <typename _DataType>
DPCTLSyclEventRef (*dpnp_put_along_axis_ext_c)(DPCTLSyclQueueRef,
void *,
long *,
void *,
size_t,
const shape_elem_type *,
size_t,
size_t,
size_t,
const DPCTLEventVectorRef) =
dpnp_put_along_axis_c<_DataType>;

template <typename _DataType, typename _IndecesType>
class dpnp_take_c_kernel;

Expand Down Expand Up @@ -1005,15 +992,6 @@ void func_map_init_indexing_func(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_put_along_axis_default_c<double>};

fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS_EXT][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_put_along_axis_ext_c<int32_t>};
fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void *)dpnp_put_along_axis_ext_c<int64_t>};
fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_put_along_axis_ext_c<float>};
fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_put_along_axis_ext_c<double>};

fmap[DPNPFuncName::DPNP_FN_TAKE][eft_BLN][eft_INT] = {
eft_BLN, (void *)dpnp_take_default_c<bool, int32_t>};
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_INT][eft_INT] = {
Expand Down
2 changes: 0 additions & 2 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_RNG_POISSON_EXT
DPNP_FN_RNG_POWER
DPNP_FN_RNG_POWER_EXT
DPNP_FN_PUT_ALONG_AXIS
DPNP_FN_PUT_ALONG_AXIS_EXT
DPNP_FN_RNG_RAYLEIGH
DPNP_FN_RNG_RAYLEIGH_EXT
DPNP_FN_RNG_SHUFFLE
Expand Down
129 changes: 0 additions & 129 deletions dpnp/dpnp_algo/dpnp_algo_indexing.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,8 @@ __all__ += [
"dpnp_diagonal",
"dpnp_fill_diagonal",
"dpnp_indices",
"dpnp_put_along_axis",
"dpnp_putmask",
"dpnp_select",
"dpnp_take_along_axis",
"dpnp_tril_indices",
"dpnp_tril_indices_from",
"dpnp_triu_indices",
Expand All @@ -69,16 +67,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_2in_1out_func_ptr_t_)(c_dpct
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_2in_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef,
void *, void * , shape_elem_type * , const size_t,
const c_dpctl.DPCTLEventVectorRef)
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_3in_with_axis_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef,
void * ,
void * ,
void * ,
const size_t,
shape_elem_type * ,
const size_t,
const size_t,
const size_t,
const c_dpctl.DPCTLEventVectorRef)


cpdef utils.dpnp_descriptor dpnp_choose(utils.dpnp_descriptor x1, list choices1):
Expand Down Expand Up @@ -283,35 +271,6 @@ cpdef object dpnp_indices(dimensions):
return dpnp_result


cpdef dpnp_put_along_axis(dpnp_descriptor arr, dpnp_descriptor indices, dpnp_descriptor values, int axis):
cdef shape_type_c arr_shape = arr.shape
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype)

cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_PUT_ALONG_AXIS_EXT, param1_type, param1_type)

utils.get_common_usm_allocation(arr, indices) # check USM allocation is common
_, _, result_sycl_queue = utils.get_common_usm_allocation(arr, values)

cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()

cdef custom_indexing_3in_with_axis_func_ptr_t func = <custom_indexing_3in_with_axis_func_ptr_t > kernel_data.ptr

cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
arr.get_data(),
indices.get_data(),
values.get_data(),
axis,
arr_shape.data(),
arr.ndim,
indices.size,
values.size,
NULL) # dep_events_ref

with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
c_dpctl.DPCTLEvent_Delete(event_ref)


cpdef dpnp_putmask(utils.dpnp_descriptor arr, utils.dpnp_descriptor mask, utils.dpnp_descriptor values):
cdef int values_size = values.size

Expand Down Expand Up @@ -341,94 +300,6 @@ cpdef utils.dpnp_descriptor dpnp_select(list condlist, list choicelist, default)
return res_array


cpdef object dpnp_take_along_axis(object arr, object indices, int axis):
cdef long size_arr = arr.size
cdef shape_type_c shape_arr = arr.shape
cdef shape_type_c output_shape
cdef long size_indices = indices.size
res_type = arr.dtype

if axis != arr.ndim - 1:
res_shape_list = list(shape_arr)
res_shape_list[axis] = 1
res_shape = tuple(res_shape_list)

output_shape = (0,) * (len(shape_arr) - 1)
ind = 0
for id, shape_axis in enumerate(shape_arr):
if id != axis:
output_shape[ind] = shape_axis
ind += 1

prod = 1
for i in range(len(output_shape)):
if output_shape[i] != 0:
prod *= output_shape[i]

result_array = dpnp.empty((prod, ), dtype=res_type)
ind_array = [None] * prod
arr_shape_offsets = [None] * len(shape_arr)
acc = 1

for i in range(len(shape_arr)):
ind = len(shape_arr) - 1 - i
arr_shape_offsets[ind] = acc
acc *= shape_arr[ind]

output_shape_offsets = [None] * len(shape_arr)
acc = 1

for i in range(len(output_shape)):
ind = len(output_shape) - 1 - i
output_shape_offsets[ind] = acc
acc *= output_shape[ind]
result_offsets = arr_shape_offsets[:] # need copy. not a reference
result_offsets[axis] = 0

for source_idx in range(size_arr):

# reconstruct x,y,z from linear source_idx
xyz = []
remainder = source_idx
for i in arr_shape_offsets:
quotient, remainder = divmod(remainder, i)
xyz.append(quotient)

# extract result axis
result_axis = []
for idx, offset in enumerate(xyz):
if idx != axis:
result_axis.append(offset)

# Construct result offset
result_offset = 0
for i, result_axis_val in enumerate(result_axis):
result_offset += (output_shape_offsets[i] * result_axis_val)

arr_elem = arr.item(source_idx)
if ind_array[result_offset] is None:
ind_array[result_offset] = 0
else:
ind_array[result_offset] += 1

if ind_array[result_offset] % size_indices == indices.item(result_offset % size_indices):
result_array[result_offset] = arr_elem

dpnp_result_array = dpnp.reshape(result_array, res_shape)
return dpnp_result_array

else:
result_array = utils_py.create_output_descriptor_py(shape_arr, res_type, None).get_pyobj()

result_array_flatiter = result_array.flat

for i in range(size_arr):
ind = size_indices * (i // size_indices) + indices.item(i % size_indices)
result_array_flatiter[i] = arr.item(ind)

return result_array


cpdef tuple dpnp_tril_indices(n, k=0, m=None):
array1 = []
array2 = []
Expand Down
37 changes: 37 additions & 0 deletions dpnp/dpnp_iface.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"array_equal",
"asnumpy",
"astype",
"check_supported_arrays_type",
"convert_single_elem_array_to_scalar",
"default_float_type",
"dpnp_queue_initialize",
Expand Down Expand Up @@ -203,6 +204,42 @@ def astype(x1, dtype, order="K", casting="unsafe", copy=True):
return dpnp_array._create_from_usm_ndarray(array_obj)


def check_supported_arrays_type(*arrays, scalar_type=False):
"""
Return ``True`` if each array has either type of scalar,
:class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
But if any array has unsupported type, ``TypeError`` will be raised.
Parameters
----------
arrays : {dpnp_array, usm_ndarray}
Input arrays to check for supported types.
scalar_type : {bool}, optional
A scalar type is also considered as supported if flag is True.
Returns
-------
out : bool
``True`` if each type of input `arrays` is supported type,
``False`` otherwise.
Raises
------
TypeError
If any input array from `arrays` is of unsupported array type.
"""

for a in arrays:
if scalar_type and dpnp.isscalar(a) or is_supported_array_type(a):
continue

raise TypeError(
"An array must be any of supported type, but got {}".format(type(a))
)
return True


def convert_single_elem_array_to_scalar(obj, keepdims=False):
"""Convert array with single element to scalar."""

Expand Down
Loading

0 comments on commit a7add8e

Please sign in to comment.