diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index ca59c8027b6..409755af919 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -86,8 +86,6 @@ enum class DPNPFuncName : size_t parameters */ DPNP_FN_AROUND, /**< Used in numpy.around() impl */ DPNP_FN_ASTYPE, /**< Used in numpy.astype() impl */ - DPNP_FN_ASTYPE_EXT, /**< Used in numpy.astype() impl, requires extra - parameters */ DPNP_FN_BITWISE_AND, /**< Used in numpy.bitwise_and() impl */ DPNP_FN_BITWISE_OR, /**< Used in numpy.bitwise_or() impl */ DPNP_FN_BITWISE_XOR, /**< Used in numpy.bitwise_xor() impl */ diff --git a/dpnp/backend/kernels/dpnp_krnl_common.cpp b/dpnp/backend/kernels/dpnp_krnl_common.cpp index 875c8ee6d1d..d575f8bdb96 100644 --- a/dpnp/backend/kernels/dpnp_krnl_common.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_common.cpp @@ -101,14 +101,6 @@ template void (*dpnp_astype_default_c)(const void *, void *, const size_t) = dpnp_astype_c<_DataType, _ResultType>; -template -DPCTLSyclEventRef (*dpnp_astype_ext_c)(DPCTLSyclQueueRef, - const void *, - void *, - const size_t, - const DPCTLEventVectorRef) = - dpnp_astype_c<_DataType, _ResultType>; - template @@ -1035,63 +1027,6 @@ void func_map_init_linalg(func_map_t &fmap) (void *) dpnp_astype_default_c, std::complex>}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_BLN][eft_BLN] = { - eft_BLN, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_BLN][eft_INT] = { - eft_INT, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_BLN][eft_LNG] = { - eft_LNG, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_BLN][eft_FLT] = { - eft_FLT, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_BLN][eft_DBL] = { - eft_DBL, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_INT][eft_BLN] = { - eft_BLN, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_INT][eft_INT] = { - eft_INT, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_INT][eft_LNG] = { - eft_LNG, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_INT][eft_FLT] = { - eft_FLT, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_INT][eft_DBL] = { - eft_DBL, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_LNG][eft_BLN] = { - eft_BLN, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_LNG][eft_INT] = { - eft_INT, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_LNG][eft_LNG] = { - eft_LNG, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_LNG][eft_FLT] = { - eft_FLT, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_LNG][eft_DBL] = { - eft_DBL, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_FLT][eft_BLN] = { - eft_BLN, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_FLT][eft_INT] = { - eft_INT, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_FLT][eft_LNG] = { - eft_LNG, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_FLT][eft_FLT] = { - eft_FLT, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_FLT][eft_DBL] = { - eft_DBL, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_DBL][eft_BLN] = { - eft_BLN, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_DBL][eft_INT] = { - eft_INT, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_DBL][eft_LNG] = { - eft_LNG, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_DBL][eft_FLT] = { - eft_FLT, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_DBL][eft_DBL] = { - eft_DBL, (void *)dpnp_astype_ext_c}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_C64][eft_C64] = { - eft_C64, - (void *)dpnp_astype_ext_c, std::complex>}; - fmap[DPNPFuncName::DPNP_FN_ASTYPE_EXT][eft_C128][eft_C128] = { - eft_C128, - (void *)dpnp_astype_ext_c, std::complex>}; - fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_INT] = { eft_INT, (void *)dpnp_dot_default_c}; fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_LNG] = { diff --git a/dpnp/dparray.pxd b/dpnp/dparray.pxd index 9f94db42f40..95a2963c7d0 100644 --- a/dpnp/dparray.pxd +++ b/dpnp/dparray.pxd @@ -50,4 +50,3 @@ cdef class dparray: cdef void * get_data(self) cpdef item(self, id=*) - cpdef dparray astype(self, dtype, order=*, casting=*, subok=*, copy=*) diff --git a/dpnp/dparray.pyx b/dpnp/dparray.pyx index 7866f46ec26..947065b23ef 100644 --- a/dpnp/dparray.pyx +++ b/dpnp/dparray.pyx @@ -41,12 +41,11 @@ from libcpp cimport bool as cpp_bool import numpy from dpnp.dpnp_algo import ( - dpnp_astype, dpnp_flatten, ) # to avoid interference with Python internal functions -from dpnp.dpnp_iface import asnumpy +from dpnp.dpnp_iface import asnumpy, astype from dpnp.dpnp_iface import get_dpnp_descriptor as iface_get_dpnp_descriptor from dpnp.dpnp_iface import prod as iface_prod from dpnp.dpnp_iface import sum as iface_sum @@ -870,47 +869,36 @@ cdef class dparray: def __truediv__(self, other): return divide(self, other) - cpdef dparray astype(self, dtype, order='K', casting='unsafe', subok=True, copy=True): - """Copy the array with data type casting. + def astype(self, dtype, order='K', casting='unsafe', subok=True, copy=True): + """ + Copy the array with data type casting. - Args: - dtype: Target type. - order ({'C', 'F', 'A', 'K'}): Row-major (C-style) or column-major (Fortran-style) order. - When ``order`` is 'A', it uses 'F' if ``a`` is column-major and uses 'C' otherwise. - And when ``order`` is 'K', it keeps strides as closely as possible. - copy (bool): If it is False and no cast happens, then this method returns the array itself. - Otherwise, a copy is returned. + Parameters + ---------- + dtype : dtype + Target data type. + order : {'C', 'F', 'A', 'K'} + Row-major (C-style) or column-major (Fortran-style) order. + When ``order`` is 'A', it uses 'F' if ``a`` is column-major and uses 'C' otherwise. + And when ``order`` is 'K', it keeps strides as closely as possible. + copy : bool + If it is False and no cast happens, then this method returns the array itself. + Otherwise, a copy is returned. - Returns: + Returns + ------- + out : dpnp.ndarray If ``copy`` is False and no cast is required, then the array itself is returned. Otherwise, it returns a (possibly casted) copy of the array. - .. note:: - This method currently does not support `order``, `casting``, ``copy``, and ``subok`` arguments. - - .. seealso:: :meth:`numpy.ndarray.astype` + Limitations + ----------- + Parameter `subok` is supported with default value. + Otherwise ``NotImplementedError`` exception will be raised. """ - if casting is not 'unsafe': - pass - elif subok is not True: - pass - elif copy is not True: - pass - elif order is not 'K': - pass - elif self.dtype == numpy.complex128 or dtype == numpy.complex128: - pass - elif self.dtype == numpy.complex64 or dtype == numpy.complex64: - pass - else: - self_desc = iface_get_dpnp_descriptor(self) - return dpnp_astype(self_desc, dtype).get_pyobj() - - result = dp2nd_array(self).astype(dtype=dtype, order=order, casting=casting, subok=subok, copy=copy) - - return nd2dp_array(result) + return astype(self, dtype, order=order, casting=casting, subok=subok, copy=copy) def conj(self): """ diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index da1a5b1dedf..a82b6fafbc8 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -42,8 +42,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_ARGMIN_EXT DPNP_FN_ARGSORT DPNP_FN_ARGSORT_EXT - DPNP_FN_ASTYPE - DPNP_FN_ASTYPE_EXT DPNP_FN_CBRT DPNP_FN_CBRT_EXT DPNP_FN_CHOLESKY @@ -329,7 +327,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*dpnp_reduction_c_t)(c_dpctl.DPCTLSyclQueueRe const long*, const c_dpctl.DPCTLEventVectorRef) -cpdef dpnp_descriptor dpnp_astype(dpnp_descriptor x1, dtype) cpdef dpnp_descriptor dpnp_flatten(dpnp_descriptor x1) diff --git a/dpnp/dpnp_algo/dpnp_algo.pyx b/dpnp/dpnp_algo/dpnp_algo.pyx index 351e912eeda..2d3be5f88a0 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pyx +++ b/dpnp/dpnp_algo/dpnp_algo.pyx @@ -54,7 +54,6 @@ import operator import numpy __all__ = [ - "dpnp_astype", "dpnp_flatten", "dpnp_queue_initialize", ] @@ -74,9 +73,6 @@ include "dpnp_algo_statistics.pxi" include "dpnp_algo_trigonometric.pxi" -ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_astype_t)(c_dpctl.DPCTLSyclQueueRef, - const void *, void * , const size_t, - const c_dpctl.DPCTLEventVectorRef) ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_flatten_t)(c_dpctl.DPCTLSyclQueueRef, void *, const size_t, const size_t, const shape_elem_type * , const shape_elem_type * , @@ -86,37 +82,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_flatten_t)(c_dpctl.DPCTLSyclQueueR const c_dpctl.DPCTLEventVectorRef) -cpdef utils.dpnp_descriptor dpnp_astype(utils.dpnp_descriptor x1, dtype): - cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype) - cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(dtype) - - cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_ASTYPE_EXT, param1_type, param2_type) - - x1_obj = x1.get_array() - - # create result array with type given by FPTR data - cdef shape_type_c result_shape = x1.shape - cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, - kernel_data.return_type, - None, - device=x1_obj.sycl_device, - usm_type=x1_obj.usm_type, - sycl_queue=x1_obj.sycl_queue) - - result_sycl_queue = result.get_array().sycl_queue - - cdef c_dpctl.SyclQueue q = result_sycl_queue - cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref() - - cdef fptr_dpnp_astype_t func = kernel_data.ptr - cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, x1.get_data(), result.get_data(), x1.size, NULL) - - with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref) - c_dpctl.DPCTLEvent_Delete(event_ref) - - return result - - cpdef utils.dpnp_descriptor dpnp_flatten(utils.dpnp_descriptor x1): cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype) diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 1610b16ceb6..ca16bb9a70c 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -596,32 +596,64 @@ def asnumpy(self): return dpt.asnumpy(self._array_obj) def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True): - """Copy the array with data type casting. + """ + Copy the array with data type casting. - Args: - dtype: Target type. - order ({'C', 'F', 'A', 'K'}): Row-major (C-style) or column-major (Fortran-style) order. - When ``order`` is 'A', it uses 'F' if ``a`` is column-major and uses 'C' otherwise. - And when ``order`` is 'K', it keeps strides as closely as possible. - copy (bool): If it is False and no cast happens, then this method returns the array itself. - Otherwise, a copy is returned. + For full documentation refer to :obj:`numpy.ndarray.astype`. - Returns: - If ``copy`` is False and no cast is required, then the array itself is returned. - Otherwise, it returns a (possibly casted) copy of the array. + Parameters + ---------- + x1 : {dpnp.ndarray, usm_ndarray} + Array data type casting. + dtype : dtype + Target data type. + order : {'C', 'F', 'A', 'K'} + Row-major (C-style) or column-major (Fortran-style) order. + When ``order`` is 'A', it uses 'F' if ``a`` is column-major and uses 'C' otherwise. + And when ``order`` is 'K', it keeps strides as closely as possible. + copy : bool + If it is False and no cast happens, then this method returns the array itself. + Otherwise, a copy is returned. + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional + Controls what kind of data casting may occur. Defaults to 'unsafe' for backwards compatibility. + 'no' means the data types should not be cast at all. + 'equiv' means only byte-order changes are allowed. + 'safe' means only casts which can preserve values are allowed. + 'same_kind' means only safe casts or casts within a kind, like float64 to float32, are allowed. + 'unsafe' means any data conversions may be done. + copy : bool, optional + By default, astype always returns a newly allocated array. If this is set to false, and the dtype, + order, and subok requirements are satisfied, the input array is returned instead of a copy. - .. note:: - This method currently does not support `order``, `casting``, ``copy``, and ``subok`` arguments. + Returns + ------- + arr_t : dpnp.ndarray + Unless `copy` is ``False`` and the other conditions for returning the input array + are satisfied, `arr_t` is a new array of the same shape as the input array, + with dtype, order given by dtype, order. - .. seealso:: :meth:`numpy.ndarray.astype` + Limitations + ----------- + Parameter `subok` is supported with default value. + Otherwise ``NotImplementedError`` exception will be raised. + + Examples + -------- + >>> import dpnp as np + >>> x = np.array([1, 2, 2.5]) + >>> x + array([1. , 2. , 2.5]) + >>> x.astype(int) + array([1, 2, 2]) """ - new_array = self.__new__(dpnp_array) - new_array._array_obj = dpt.astype( - self._array_obj, dtype, order=order, casting=casting, copy=copy - ) - return new_array + if subok is not True: + raise NotImplementedError( + f"subok={subok} is currently not supported" + ) + + return dpnp.astype(self, dtype, order=order, casting=casting, copy=copy) # 'base', # 'byteswap', diff --git a/dpnp/dpnp_iface.py b/dpnp/dpnp_iface.py index ecb3e48aaf6..d6d5b3a4861 100644 --- a/dpnp/dpnp_iface.py +++ b/dpnp/dpnp_iface.py @@ -46,6 +46,7 @@ import dpctl.tensor as dpt import numpy +import dpnp from dpnp.dpnp_algo import * from dpnp.dpnp_array import dpnp_array from dpnp.dpnp_utils import * @@ -149,42 +150,57 @@ def asnumpy(input, order="C"): return numpy.asarray(input, order=order) -def astype(x1, dtype, order="K", casting="unsafe", subok=True, copy=True): - """Copy the array with data type casting.""" - if isinstance(x1, dpnp_array): - return x1.astype(dtype, order=order, casting=casting, copy=copy) - - if isinstance(x1, dpt.usm_ndarray): - return dpt.astype(x1, dtype, order=order, casting=casting, copy=copy) - - x1_desc = get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) - if not x1_desc: - pass - elif order != "K": - pass - elif casting != "unsafe": - pass - elif not subok: - pass - elif not copy: - pass - elif x1_desc.dtype == numpy.complex128 or dtype == numpy.complex128: - pass - elif x1_desc.dtype == numpy.complex64 or dtype == numpy.complex64: - pass - else: - return dpnp_astype(x1_desc, dtype).get_pyobj() - - return call_origin( - numpy.ndarray.astype, - x1, - dtype, - order=order, - casting=casting, - subok=subok, - copy=copy, +def astype(x1, dtype, order="K", casting="unsafe", copy=True): + """ + Copy the array with data type casting. + + Parameters + ---------- + x1 : {dpnp.ndarray, usm_ndarray} + Array data type casting. + dtype : dtype + Target data type. + order : {'C', 'F', 'A', 'K'} + Row-major (C-style) or column-major (Fortran-style) order. + When ``order`` is 'A', it uses 'F' if ``a`` is column-major and uses 'C' otherwise. + And when ``order`` is 'K', it keeps strides as closely as possible. + copy : bool + If it is False and no cast happens, then this method returns the array itself. + Otherwise, a copy is returned. + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional + Controls what kind of data casting may occur. Defaults to 'unsafe' for backwards compatibility. + 'no' means the data types should not be cast at all. + 'equiv' means only byte-order changes are allowed. + 'safe' means only casts which can preserve values are allowed. + 'same_kind' means only safe casts or casts within a kind, like float64 to float32, are allowed. + 'unsafe' means any data conversions may be done. + copy : bool, optional + By default, astype always returns a newly allocated array. If this is set to false, and the dtype, + order, and subok requirements are satisfied, the input array is returned instead of a copy. + + Returns + ------- + arr_t : dpnp.ndarray + Unless `copy` is ``False`` and the other conditions for returning the input array + are satisfied, `arr_t` is a new array of the same shape as the input array, + with dtype, order given by dtype, order. + + """ + + if order is None: + order = "K" + + x1_obj = dpnp.get_usm_ndarray(x1) + array_obj = dpt.astype( + x1_obj, dtype, order=order, casting=casting, copy=copy ) + # return x1 if dpctl returns a zero copy of x1_obj + if array_obj is x1_obj and isinstance(x1, dpnp_array): + return x1 + + return dpnp_array._create_from_usm_ndarray(array_obj) + def convert_single_elem_array_to_scalar(obj, keepdims=False): """Convert array with single element to scalar.""" @@ -350,37 +366,45 @@ def get_normalized_queue_device(obj=None, device=None, sycl_queue=None): Utility to process complementary keyword arguments 'device' and 'sycl_queue' in subsequent calls of functions from `dpctl.tensor` module. - If both arguments 'device' and 'sycl_queue' have default value `None` + If both arguments 'device' and 'sycl_queue' have default value ``None`` and 'obj' has `sycl_queue` attribute, it assumes that Compute Follows Data approach has to be applied and so the resulting SYCL queue will be normalized based on the queue value from 'obj'. - Args: - obj (optional): A python object. Can be an instance of `dpnp_array`, - `dpctl.tensor.usm_ndarray`, an object representing SYCL USM allocation - and implementing `__sycl_usm_array_interface__` protocol, - an instance of `numpy.ndarray`, an object supporting Python buffer protocol, - a Python scalar, or a (possibly nested) sequence of Python scalars. - sycl_queue (:class:`dpctl.SyclQueue`, optional): - explicitly indicates where USM allocation is done - and the population code (if any) is executed. - Value `None` is interpreted as get the SYCL queue - from `obj` parameter if not None, from `device` keyword, - or use default queue. - Default: None - device (string, :class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue, - :class:`dpctl.tensor.Device`, optional): - array-API keyword indicating non-partitioned SYCL device - where array is allocated. + Parameters + ---------- + obj : object, optional + A python object. Can be an instance of `dpnp_array`, + `dpctl.tensor.usm_ndarray`, an object representing SYCL USM allocation + and implementing `__sycl_usm_array_interface__` protocol, an instance + of `numpy.ndarray`, an object supporting Python buffer protocol, + a Python scalar, or a (possibly nested) sequence of Python scalars. + sycl_queue : class:`dpctl.SyclQueue`, optional + A queue which explicitly indicates where USM allocation is done + and the population code (if any) is executed. + Value ``None`` is interpreted as to get the SYCL queue from either + `obj` parameter if not ``None`` or from `device` keyword, + or to use default queue. + device : {string, :class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue, + :class:`dpctl.tensor.Device`}, optional + An array-API keyword indicating non-partitioned SYCL device + where array is allocated. + Returns - :class:`dpctl.SyclQueue` object normalized by `normalize_queue_device` call + ------- + sycl_queue: dpctl.SyclQueue + A :class:`dpctl.SyclQueue` object normalized by `normalize_queue_device` call of `dpctl.tensor` module invoked with 'device' and 'sycl_queue' values. If both incoming 'device' and 'sycl_queue' are None and 'obj' has `sycl_queue` attribute, the normalization will be performed for 'obj.sycl_queue' value. - Raises: - TypeError: if argument is not of the expected type, or keywords - imply incompatible queues. + + Raises + ------ + TypeError + If argument is not of the expected type, or keywords imply incompatible queues. + """ + if ( device is None and sycl_queue is None @@ -389,12 +413,9 @@ def get_normalized_queue_device(obj=None, device=None, sycl_queue=None): ): sycl_queue = obj.sycl_queue - # TODO: remove check dpt._device has attribute 'normalize_queue_device' - if hasattr(dpt._device, "normalize_queue_device"): - return dpt._device.normalize_queue_device( - sycl_queue=sycl_queue, device=device - ) - return sycl_queue + return dpt._device.normalize_queue_device( + sycl_queue=sycl_queue, device=device + ) def get_usm_ndarray(a): diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index a3ebbc6124c..3c2e18a5d4c 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -56,6 +56,7 @@ "atleast_3d", "broadcast_arrays", "broadcast_to", + "can_cast", "concatenate", "copyto", "expand_dims", @@ -402,6 +403,47 @@ def broadcast_to(array, /, shape, subok=False): return dpnp_array._create_from_usm_ndarray(new_array) +def can_cast(from_, to, casting="safe"): + """ + Returns ``True`` if cast between data types can occur according to the casting rule. + + If `from` is a scalar or array scalar, also returns ``True`` if the scalar value can + be cast without overflow or truncation to an integer. + + For full documentation refer to :obj:`numpy.can_cast`. + + Parameters + ---------- + from : dpnp.array, dtype + Source data type. + to : dtype + Target data type. + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional + Controls what kind of data casting may occur. + + Returns + ------- + out: bool + True if cast can occur according to the casting rule. + + See Also + -------- + :obj:`dpnp.result_type` : Returns the type that results from applying the NumPy + type promotion rules to the arguments. + + """ + + if dpnp.is_supported_array_type(to): + raise TypeError("Cannot construct a dtype from an array") + + dtype_from = ( + from_.dtype + if dpnp.is_supported_array_type(from_) + else dpnp.dtype(from_) + ) + return dpt.can_cast(dtype_from, to, casting) + + def concatenate( arrays, /, *, axis=0, out=None, dtype=None, casting="same_kind" ): @@ -519,7 +561,7 @@ def copyto(dst, src, casting="same_kind", where=True): elif not dpnp.is_supported_array_type(src): src = dpnp.array(src, sycl_queue=dst.sycl_queue) - if not dpt.can_cast(src.dtype, dst.dtype, casting=casting): + if not dpnp.can_cast(src.dtype, dst.dtype, casting=casting): raise TypeError( f"Cannot cast from {src.dtype} to {dst.dtype} " f"according to the rule {casting}." diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index 9ea1ba4fc13..8fdb70ddf4d 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -3,7 +3,13 @@ tests/test_histograms.py::TestHistogram::test_density tests/test_random.py::TestDistributionsMultivariateNormal::test_moments tests/test_random.py::TestDistributionsMultivariateNormal::test_output_shape_check tests/test_random.py::TestDistributionsMultivariateNormal::test_seed +tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray(x).astype(dpnp.complex64)] +tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.astype(dpnp.asarray(x), dpnp.int8)] +tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.astype(dpnp.asarray(x), object)] tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.vstack([x, x]).T] +tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: (dpnp.asarray([(i, i) for i in x], [("a", int), ("b", int)]).view(dpnp.recarray))] +tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray([(i, i) for i in x], [("a", object), ("b", dpnp.int32)])]] +tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray(x).astype(dpnp.int8)] tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-trapz-data19] tests/test_sycl_queue.py::test_1in_1out[opencl:cpu:0-trapz-data19] @@ -70,11 +76,6 @@ tests/test_linalg.py::test_norm1[None-3-[7]] tests/test_linalg.py::test_norm1[None-3-[1, 2]] tests/test_linalg.py::test_norm1[None-3-[1, 0]] -tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray([[i, i] for i in x])] -tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: (dpnp.asarray([(i, i) for i in x], [("a", int), ("b", int)]).view(dpnp.recarray))] -tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray([(i, i) for i in x], [("a", object), ("b", dpnp.int32)])]] -tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray(x).astype(dpnp.int8)] - tests/test_strides.py::test_strides_1arg[(10,)-None-cbrt] tests/test_strides.py::test_strides_1arg[(10,)-None-degrees] tests/test_strides.py::test_strides_1arg[(10,)-None-exp2] @@ -110,32 +111,15 @@ tests/test_umath.py::test_umaths[('spacing', 'd')] tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestAngle::test_angle tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_inplace tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real_inplace -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestScalarConversion::test_scalar_conversion tests/third_party/cupy/core_tests/test_ndarray_conversion.py::TestNdarrayToBytes_param_0_{shape=()}::test_item tests/third_party/cupy/core_tests/test_ndarray_conversion.py::TestNdarrayToBytes_param_1_{shape=(1,)}::test_item tests/third_party/cupy/core_tests/test_ndarray_conversion.py::TestNdarrayToBytes_param_2_{shape=(2, 3)}::test_item tests/third_party/cupy/core_tests/test_ndarray_conversion.py::TestNdarrayToBytes_param_3_{order='C', shape=(2, 3)}::test_item tests/third_party/cupy/core_tests/test_ndarray_conversion.py::TestNdarrayToBytes_param_4_{order='F', shape=(2, 3)}::test_item -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_type -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_strides -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_strides_negative -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_strides_swapped -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_type_c_contiguous_no_copy -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_type_f_contiguous_no_copy -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_flatten -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_flatten_copied -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_isinstance_numpy_copy -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_isinstance_numpy_copy_not_slice -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_isinstance_numpy_copy_wrong_dtype -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_isinstance_numpy_copy_wrong_shape -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_view -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_view_0d -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_view_0d_raise -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_view_itemsize -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_view_non_contiguous_raise -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestNumPyArrayCopyView_param_0_{src_order='C'}::test_isinstance_numpy_view_copy_f -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestNumPyArrayCopyView_param_1_{src_order='F'}::test_isinstance_numpy_view_copy_f + +tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayFlatten::test_flatten_order +tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayFlatten::test_flatten_order_copied +tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayFlatten::test_flatten_order_transposed tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestArrayReduction::test_ptp_all tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestArrayReduction::test_ptp_all_keepdims diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index d707cb03cfd..301fbc456ae 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -9,8 +9,13 @@ tests/test_random.py::TestPermutationsTestShuffle::test_no_miss_numbers[float32] tests/test_random.py::TestPermutationsTestShuffle::test_no_miss_numbers[float64] tests/test_random.py::TestPermutationsTestShuffle::test_no_miss_numbers[int32] tests/test_random.py::TestPermutationsTestShuffle::test_no_miss_numbers[int64] -tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.array([])] -tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.astype(dpnp.asarray(x), dpnp.float32)] +tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray(x).astype(dpnp.complex64)] +tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.astype(dpnp.asarray(x), dpnp.int8)] +tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.astype(dpnp.asarray(x), object)] +tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.vstack([x, x]).T] +tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: (dpnp.asarray([(i, i) for i in x], [("a", int), ("b", int)]).view(dpnp.recarray))] +tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray([(i, i) for i in x], [("a", object), ("b", dpnp.int32)])]] +tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray(x).astype(dpnp.int8)] tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-copy-data3] tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-cumprod-data4] @@ -131,13 +136,8 @@ tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsPois tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsPoisson_param_2_{lam_shape=(3, 2), shape=(4, 3, 2)}::test_poisson tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsPoisson_param_3_{lam_shape=(3, 2), shape=(3, 2)}::test_poisson -tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray([[i, i] for i in x])] - tests/test_histograms.py::TestHistogram::test_density -tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.astype(dpnp.asarray(x), dpnp.int8)] -tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.astype(dpnp.asarray(x), object)] -tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.vstack([x, x]).T] tests/third_party/cupy/core_tests/test_ndarray_conversion.py::TestNdarrayItemRaise_param_0_{shape=(0,)}::test_item tests/third_party/cupy/core_tests/test_ndarray_conversion.py::TestNdarrayItemRaise_param_1_{shape=(2, 3)}::test_item tests/third_party/cupy/core_tests/test_ndarray_conversion.py::TestNdarrayItemRaise_param_2_{shape=(1, 0, 1)}::test_item @@ -237,34 +237,13 @@ tests/test_linalg.py::test_matrix_rank[None-[[1, 2], [3, 4]]-float32] tests/test_linalg.py::test_matrix_rank[None-[[1, 2], [3, 4]]-int64] tests/test_linalg.py::test_matrix_rank[None-[[1, 2], [3, 4]]-int32] -tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: (dpnp.asarray([(i, i) for i in x], [("a", int), ("b", int)]).view(dpnp.recarray))] -tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray([(i, i) for i in x], [("a", object), ("b", dpnp.int32)])]] -tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray(x).astype(dpnp.int8)] - tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestAngle::test_angle tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_inplace tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real_inplace -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestScalarConversion::test_scalar_conversion -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_type -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_strides -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_strides_negative -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_strides_swapped -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_type_c_contiguous_no_copy -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_type_f_contiguous_no_copy -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_flatten_copied -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_flatten -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_isinstance_numpy_copy -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_isinstance_numpy_copy_not_slice -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_isinstance_numpy_copy_wrong_dtype -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_isinstance_numpy_copy_wrong_shape -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_view -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_view_0d -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_view_0d_raise -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_view_itemsize -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_view_non_contiguous_raise -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestNumPyArrayCopyView_param_0_{src_order='C'}::test_isinstance_numpy_view_copy_f -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestNumPyArrayCopyView_param_1_{src_order='F'}::test_isinstance_numpy_view_copy_f + +tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayFlatten::test_flatten_order +tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayFlatten::test_flatten_order_copied +tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayFlatten::test_flatten_order_transposed tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestArrayReduction::test_ptp_all tests/third_party/cupy/core_tests/test_ndarray_reduction.py::TestArrayReduction::test_ptp_all_keepdims diff --git a/tests/test_arraymanipulation.py b/tests/test_arraymanipulation.py index 34e673cb82f..bf61634e52c 100644 --- a/tests/test_arraymanipulation.py +++ b/tests/test_arraymanipulation.py @@ -928,3 +928,14 @@ def test_subok_error(): with pytest.raises(NotImplementedError): dpnp.broadcast_arrays(x, subok=True) dpnp.broadcast_to(x, (4, 4), subok=True) + + +def test_can_cast(): + X = dpnp.ones((2, 2), dtype=dpnp.int64) + pytest.raises(TypeError, dpnp.can_cast, X, 1) + pytest.raises(TypeError, dpnp.can_cast, X, X) + + X_np = numpy.ones((2, 2), dtype=numpy.int64) + assert dpnp.can_cast(X, "float32") == numpy.can_cast(X_np, "float32") + assert dpnp.can_cast(X, dpnp.int32) == numpy.can_cast(X_np, numpy.int32) + assert dpnp.can_cast(X, dpnp.int64) == numpy.can_cast(X_np, numpy.int64) diff --git a/tests/test_dparray.py b/tests/test_dparray.py index 47d8c5ca931..e3f4e80a4cb 100644 --- a/tests/test_dparray.py +++ b/tests/test_dparray.py @@ -13,6 +13,7 @@ ) +@pytest.mark.usefixtures("suppress_complex_warning") @pytest.mark.parametrize("res_dtype", get_all_dtypes()) @pytest.mark.parametrize("arr_dtype", get_all_dtypes()) @pytest.mark.parametrize( @@ -28,6 +29,12 @@ def test_astype(arr, arr_dtype, res_dtype): assert_allclose(expected, result) +def test_astype_subok_error(): + x = dpnp.ones((4)) + with pytest.raises(NotImplementedError): + x.astype("i4", subok=False) + + @pytest.mark.parametrize("arr_dtype", get_all_dtypes()) @pytest.mark.parametrize( "arr", diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index 85d015f3775..2a25e7de573 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -1647,6 +1647,7 @@ def test_sum_empty_out(dtype): assert_array_equal(out.asnumpy(), numpy.array(0, dtype=dtype)) +@pytest.mark.usefixtures("suppress_complex_warning") @pytest.mark.parametrize( "shape", [ diff --git a/tests/test_random.py b/tests/test_random.py index 17383c56610..1fd058e2d13 100644 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -151,9 +151,7 @@ def test_randn_normal_distribution(): assert math.isclose(mean, expected_mean, abs_tol=0.03) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsBeta(TestDistribution): def test_moments(self): @@ -179,9 +177,7 @@ def test_seed(self): self.check_seed("beta", {"a": a, "b": b}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsBinomial(TestDistribution): def test_extreme_value(self): @@ -220,9 +216,7 @@ def test_seed(self): self.check_seed("binomial", {"n": n, "p": p}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsChisquare(TestDistribution): def test_invalid_args(self): @@ -234,9 +228,7 @@ def test_seed(self): self.check_seed("chisquare", {"df": df}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsExponential(TestDistribution): def test_invalid_args(self): @@ -248,9 +240,7 @@ def test_seed(self): self.check_seed("exponential", {"scale": scale}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsF(TestDistribution): def test_moments(self): @@ -284,9 +274,7 @@ def test_seed(self): self.check_seed("f", {"dfnum": dfnum, "dfden": dfden}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsGamma(TestDistribution): def test_moments(self): @@ -337,9 +325,7 @@ def test_seed(self): self.check_seed("geometric", {"p": p}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsGumbel(TestDistribution): def test_extreme_value(self): @@ -371,9 +357,7 @@ def test_seed(self): self.check_seed("gumbel", {"loc": loc, "scale": scale}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-6001" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsHypergeometric(TestDistribution): def test_extreme_value(self): @@ -460,9 +444,7 @@ def test_seed(self): ) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsLaplace(TestDistribution): def test_extreme_value(self): @@ -493,9 +475,7 @@ def test_seed(self): self.check_seed("laplace", {"loc": loc, "scale": scale}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsLogistic(TestDistribution): def test_moments(self): @@ -521,9 +501,7 @@ def test_seed(self): self.check_seed("logistic", {"loc": loc, "scale": scale}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsLognormal(TestDistribution): def test_extreme_value(self): @@ -559,9 +537,7 @@ def test_seed(self): self.check_seed("lognormal", {"mean": mean, "sigma": sigma}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsMultinomial(TestDistribution): def test_extreme_value(self): @@ -658,9 +634,7 @@ def test_seed(self): self.check_seed("multivariate_normal", {"mean": mean, "cov": cov}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsNegativeBinomial(TestDistribution): def test_extreme_value(self): @@ -692,9 +666,7 @@ def test_seed(self): self.check_seed("negative_binomial", {"n": n, "p": p}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") class TestDistributionsNormal(TestDistribution): def test_extreme_value(self): loc = 5 @@ -724,9 +696,7 @@ def test_seed(self): self.check_seed("normal", {"loc": loc, "scale": scale}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsNoncentralChisquare: @pytest.mark.parametrize( @@ -772,9 +742,7 @@ def test_seed(self, df): assert_allclose(a1, a2, rtol=1e-07, atol=0) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsPareto(TestDistribution): def test_moments(self): @@ -793,9 +761,7 @@ def test_seed(self): self.check_seed("pareto", {"a": a}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsPoisson(TestDistribution): def test_extreme_value(self): @@ -817,9 +783,7 @@ def test_seed(self): self.check_seed("poisson", {"lam": lam}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsPower(TestDistribution): def test_moments(self): @@ -839,9 +803,7 @@ def test_seed(self): self.check_seed("power", {"a": a}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsRayleigh(TestDistribution): def test_extreme_value(self): @@ -883,9 +845,7 @@ def test_seed(self): self.check_seed("standard_exponential", {}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsStandardGamma(TestDistribution): def test_extreme_value(self): @@ -917,9 +877,7 @@ def test_seed(self): self.check_seed("standard_normal", {}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsStandardT(TestDistribution): def test_moments(self): @@ -938,9 +896,7 @@ def test_seed(self): self.check_seed("standard_t", {"df": 10.0}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsTriangular(TestDistribution): def test_moments(self): @@ -987,9 +943,7 @@ def test_seed(self): ) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") class TestDistributionsUniform(TestDistribution): def test_extreme_value(self): low = 1.0 @@ -1014,9 +968,7 @@ def test_seed(self): self.check_seed("uniform", {"low": low, "high": high}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsVonmises: @pytest.mark.parametrize( @@ -1057,9 +1009,7 @@ def test_seed(self, kappa): assert_allclose(a1, a2, rtol=1e-07, atol=0) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsWald(TestDistribution): def test_moments(self): @@ -1091,9 +1041,7 @@ def test_seed(self): self.check_seed("wald", {"mean": mean, "scale": scale}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsWeibull(TestDistribution): def test_extreme_value(self): @@ -1110,9 +1058,7 @@ def test_seed(self): self.check_seed("weibull", {"a": a}) -@pytest.mark.skipif( - not has_support_aspect64(), reason="Failed on Iris Xe: SAT-5989" -) +@pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestDistributionsZipf(TestDistribution): def test_invalid_args(self): @@ -1159,14 +1105,15 @@ def test_no_miss_numbers(self, dtype): actual_x = dpnp.sort(output_x) assert_array_equal(actual_x, desired_x) + @pytest.mark.skipif(not has_support_aspect64(), reason="Failed on Iris Xe") @pytest.mark.parametrize( "conv", [ lambda x: dpnp.array([]), - # lambda x: dpnp.astype(dpnp.asarray(x), dpnp.int8), + lambda x: dpnp.astype(dpnp.asarray(x), dpnp.int8), lambda x: dpnp.astype(dpnp.asarray(x), dpnp.float32), - # lambda x: dpnp.asarray(x).astype(dpnp.complex64), - # lambda x: dpnp.astype(dpnp.asarray(x), object), + lambda x: dpnp.asarray(x).astype(dpnp.complex64), + lambda x: dpnp.astype(dpnp.asarray(x), object), lambda x: dpnp.asarray([[i, i] for i in x]), lambda x: dpnp.vstack([x, x]).T, lambda x: ( @@ -1180,10 +1127,10 @@ def test_no_miss_numbers(self, dtype): ], ids=[ "lambda x: dpnp.array([])", - # 'lambda x: dpnp.astype(dpnp.asarray(x), dpnp.int8)', + "lambda x: dpnp.astype(dpnp.asarray(x), dpnp.int8)", "lambda x: dpnp.astype(dpnp.asarray(x), dpnp.float32)", - # 'lambda x: dpnp.asarray(x).astype(dpnp.complex64)', - # 'lambda x: dpnp.astype(dpnp.asarray(x), object)', + "lambda x: dpnp.asarray(x).astype(dpnp.complex64)", + "lambda x: dpnp.astype(dpnp.asarray(x), object)", "lambda x: dpnp.asarray([[i, i] for i in x])", "lambda x: dpnp.vstack([x, x]).T", "lambda x: (dpnp.asarray([(i, i) for i in x], [" diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 4334dd49099..136ff1c7181 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -166,6 +166,7 @@ def test_array_creation_follow_device(func, args, kwargs, device): assert_sycl_queue_equal(y.sycl_queue, x.sycl_queue) +@pytest.mark.skip("muted until the issue reported by SAT-5969 is resolved") @pytest.mark.parametrize( "func, args, kwargs", [ diff --git a/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py b/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py index 88915e760dc..dd5e56b0a81 100644 --- a/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py +++ b/tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py @@ -1,13 +1,9 @@ -import unittest - import numpy import pytest import dpnp as cupy from tests.third_party.cupy import testing -# from cupy import util - def astype_without_warning(x, dtype, *args, **kwargs): dtype = numpy.dtype(dtype) @@ -18,8 +14,14 @@ def astype_without_warning(x, dtype, *args, **kwargs): return x.astype(dtype, *args, **kwargs) -@testing.gpu -class TestArrayCopyAndView(unittest.TestCase): +def get_strides(xp, a): + if xp is numpy: + return tuple(el // a.itemsize for el in a.strides) + return a.strides + + +@pytest.mark.skip("'dpnp_array' object has no attribute 'view' yet") +class TestView: @testing.numpy_cupy_array_equal() def test_view(self, xp): a = testing.shaped_arange((4,), xp, dtype=numpy.float32) @@ -54,6 +56,173 @@ def test_view_non_contiguous_raise(self, dtype): with pytest.raises(ValueError): a.view(dtype=dtype) + @testing.for_dtypes([numpy.int16, numpy.int64]) + @testing.with_requires("numpy>=1.23") + def test_view_f_contiguous(self, dtype): + for xp in (numpy, cupy): + a = testing.shaped_arange((2, 2, 2), xp, dtype=numpy.float32) + a = a.T + with pytest.raises(ValueError): + a.view(dtype=dtype) + + def test_view_assert_divisible(self): + for xp in (numpy, cupy): + a = testing.shaped_arange((3,), xp, dtype=numpy.int32) + with pytest.raises(ValueError): + a.view(dtype=numpy.int64) + + @testing.for_dtypes([numpy.float32, numpy.float64]) + @testing.numpy_cupy_array_equal(strides_check=True) + def test_view_relaxed_contiguous(self, xp, dtype): + a = testing.shaped_arange((1, 3, 5), xp, dtype=dtype) + a = xp.moveaxis(a, 0, 2) # (3, 5, 1) + b = a.view(dtype=numpy.int32) + return b + + @pytest.mark.parametrize( + ("order", "shape"), + [ + ("C", (3,)), + ("C", (3, 5)), + ("C", (0,)), + ("C", (1, 3)), + ("C", (3, 1)), + ], + ids=str, + ) + @testing.numpy_cupy_equal() + def test_view_flags_smaller(self, xp, order, shape): + a = xp.zeros(shape, numpy.int32, order) + b = a.view(numpy.int16) + return b.flags.c_contiguous, b.flags.f_contiguous, b.flags.owndata + + @pytest.mark.parametrize( + ("order", "shape"), + [ + ("F", (3, 5)), + ], + ids=str, + ) + @testing.with_requires("numpy>=1.23") + def test_view_flags_smaller_invalid(self, order, shape): + for xp in (numpy, cupy): + a = xp.zeros(shape, numpy.int32, order) + with pytest.raises(ValueError): + a.view(numpy.int16) + + @pytest.mark.parametrize( + ("order", "shape"), + [ + ("C", (6,)), + ("C", (3, 10)), + ("C", (0,)), + ("C", (1, 6)), + ("C", (3, 2)), + ], + ids=str, + ) + @testing.numpy_cupy_equal() + def test_view_flags_larger(self, xp, order, shape): + a = xp.zeros(shape, numpy.int16, order) + b = a.view(numpy.int32) + return b.flags.c_contiguous, b.flags.f_contiguous, b.flags.owndata + + @pytest.mark.parametrize( + ("order", "shape"), + [ + ("F", (6, 5)), + ("F", (2, 3)), + ], + ids=str, + ) + @testing.with_requires("numpy>=1.23") + def test_view_flags_larger_invalid(self, order, shape): + for xp in (numpy, cupy): + a = xp.zeros(shape, numpy.int16, order) + with pytest.raises(ValueError): + a.view(numpy.int32) + + @testing.with_requires("numpy>=1.23") + @testing.numpy_cupy_array_equal() + def test_view_smaller_dtype_multiple(self, xp): + # x is non-contiguous + x = xp.arange(10, dtype=xp.int32)[::2] + with pytest.raises(ValueError): + x.view(xp.int16) + return x[:, xp.newaxis].view(xp.int16) + + @testing.with_requires("numpy>=1.23") + @testing.numpy_cupy_array_equal() + def test_view_smaller_dtype_multiple2(self, xp): + # x is non-contiguous, and stride[-1] != 0 + x = xp.ones((3, 4), xp.int32)[:, :1:2] + return x.view(xp.int16) + + @testing.with_requires("numpy>=1.23") + @testing.numpy_cupy_array_equal() + def test_view_larger_dtype_multiple(self, xp): + # x is non-contiguous in the first dimension, contiguous in the last + x = xp.arange(20, dtype=xp.int16).reshape(10, 2)[::2, :] + return x.view(xp.int32) + + @testing.with_requires("numpy>=1.23") + @testing.numpy_cupy_array_equal() + def test_view_non_c_contiguous(self, xp): + # x is contiguous in axis=-1, but not C-contiguous in other axes + x = ( + xp.arange(2 * 3 * 4, dtype=xp.int8) + .reshape(2, 3, 4) + .transpose(1, 0, 2) + ) + return x.view(xp.int16) + + @testing.numpy_cupy_array_equal() + def test_view_larger_dtype_zero_sized(self, xp): + x = xp.ones((3, 20), xp.int16)[:0, ::2] + return x.view(xp.int32) + + +class TestArrayCopy: + @testing.for_orders("CF") + @testing.for_dtypes( + [numpy.int16, numpy.int64, numpy.float16, numpy.float64] + ) + @testing.numpy_cupy_array_equal() + def test_isinstance_numpy_copy(self, xp, dtype, order): + a = numpy.arange(100, dtype=dtype).reshape(10, 10, order=order) + b = xp.empty(a.shape, dtype=dtype, order=order) + b[:] = a + return b + + @pytest.mark.skip("Doesn't raise ValueError in numpy") + def test_isinstance_numpy_copy_wrong_dtype(self): + a = numpy.arange(100, dtype=numpy.float32).reshape(10, 10) + b = cupy.empty(a.shape, dtype=numpy.int32) + with pytest.raises(ValueError): + b[:] = a + + def test_isinstance_numpy_copy_wrong_shape(self): + for xp in (numpy, cupy): + a = numpy.arange(100, dtype=numpy.float32).reshape(10, 10) + b = cupy.empty(100, dtype=a.dtype) + with pytest.raises(ValueError): + b[:] = a + + @testing.numpy_cupy_array_equal() + def test_isinstance_numpy_copy_not_slice(self, xp): + a = xp.arange(5, dtype=numpy.float32) + a[a < 3] = 0 + return a + + @pytest.mark.skip("copy from host to device is allowed") + def test_copy_host_to_device_view(self): + dev = cupy.empty((10, 10), dtype=numpy.float32)[2:5, 1:8] + host = numpy.arange(3 * 7, dtype=numpy.float32).reshape(3, 7) + with pytest.raises(ValueError): + dev[:] = host + + +class TestArrayFlatten: @testing.numpy_cupy_array_equal() def test_flatten(self, xp): a = testing.shaped_arange((2, 3, 4), xp) @@ -67,10 +236,32 @@ def test_flatten_copied(self, xp): return b @testing.numpy_cupy_array_equal() - def test_transposed_flatten(self, xp): + def test_flatten_transposed(self, xp): a = testing.shaped_arange((2, 3, 4), xp).transpose(2, 0, 1) return a.flatten() + @testing.for_orders("CFAK") + @testing.numpy_cupy_array_equal() + def test_flatten_order(self, xp, order): + a = testing.shaped_arange((2, 3, 4), xp) + return a.flatten(order) + + @testing.for_orders("CFAK") + @testing.numpy_cupy_array_equal() + def test_flatten_order_copied(self, xp, order): + a = testing.shaped_arange((4,), xp) + b = a.flatten(order=order) + a[:] = 1 + return b + + @testing.for_orders("CFAK") + @testing.numpy_cupy_array_equal() + def test_flatten_order_transposed(self, xp, order): + a = testing.shaped_arange((2, 3, 4), xp).transpose(2, 0, 1) + return a.flatten(order=order) + + +class TestArrayFill: @testing.for_all_dtypes() @testing.numpy_cupy_array_equal() def test_fill(self, xp, dtype): @@ -78,18 +269,32 @@ def test_fill(self, xp, dtype): a.fill(1) return a - @testing.for_all_dtypes() - @testing.numpy_cupy_array_equal() - def test_fill_with_numpy_scalar_ndarray(self, xp, dtype): - a = testing.shaped_arange((2, 3, 4), xp, dtype) - a.fill(numpy.ones((), dtype=dtype)) + @testing.with_requires("numpy>=1.24.0") + @testing.for_all_dtypes_combination(("dtype1", "dtype2")) + @testing.numpy_cupy_array_equal(accept_error=numpy.ComplexWarning) + def test_fill_with_numpy_scalar_ndarray(self, xp, dtype1, dtype2): + a = testing.shaped_arange((2, 3, 4), xp, dtype1) + a.fill(numpy.ones((), dtype=dtype2)) + return a + + @testing.with_requires("numpy>=1.24.0") + @testing.for_all_dtypes_combination(("dtype1", "dtype2")) + @testing.numpy_cupy_array_equal(accept_error=numpy.ComplexWarning) + def test_fill_with_cupy_scalar_ndarray(self, xp, dtype1, dtype2): + a = testing.shaped_arange((2, 3, 4), xp, dtype1) + b = xp.ones((), dtype=dtype2) + a.fill(b) return a + @pytest.mark.skip( + "it's allowed to broadcast dpnp array while filling, no exception then" + ) @testing.for_all_dtypes() - def test_fill_with_numpy_nonscalar_ndarray(self, dtype): + def test_fill_with_nonscalar_ndarray(self, dtype): a = testing.shaped_arange((2, 3, 4), cupy, dtype) - with self.assertRaises(ValueError): - a.fill(numpy.ones((1,), dtype=dtype)) + for xp in (numpy, cupy): + with pytest.raises(ValueError): + a.fill(xp.ones((1,), dtype=dtype)) @testing.for_all_dtypes() @testing.numpy_cupy_array_equal() @@ -99,6 +304,8 @@ def test_transposed_fill(self, xp, dtype): b.fill(1) return b + +class TestArrayAsType: @testing.for_orders(["C", "F", "A", "K", None]) @testing.for_all_dtypes_combination(("src_dtype", "dst_dtype")) @testing.numpy_cupy_array_equal() @@ -106,6 +313,13 @@ def test_astype(self, xp, src_dtype, dst_dtype, order): a = testing.shaped_arange((2, 3, 4), xp, src_dtype) return astype_without_warning(a, dst_dtype, order=order) + @testing.for_orders(["C", "F", "A", "K", None]) + @testing.for_all_dtypes_combination(("src_dtype", "dst_dtype")) + @testing.numpy_cupy_array_equal() + def test_astype_empty(self, xp, src_dtype, dst_dtype, order): + a = testing.shaped_arange((2, 0, 4), xp, src_dtype) + return astype_without_warning(a, dst_dtype, order=order) + @testing.for_orders("CFAK") @testing.for_all_dtypes_combination(("src_dtype", "dst_dtype")) def test_astype_type(self, src_dtype, dst_dtype, order): @@ -113,14 +327,14 @@ def test_astype_type(self, src_dtype, dst_dtype, order): b = astype_without_warning(a, dst_dtype, order=order) a_cpu = testing.shaped_arange((2, 3, 4), numpy, src_dtype) b_cpu = astype_without_warning(a_cpu, dst_dtype, order=order) - self.assertEqual(b.dtype.type, b_cpu.dtype.type) + assert b.dtype.type == b_cpu.dtype.type @testing.for_orders("CAK") @testing.for_all_dtypes() def test_astype_type_c_contiguous_no_copy(self, dtype, order): a = testing.shaped_arange((2, 3, 4), cupy, dtype) b = a.astype(dtype, order=order, copy=False) - self.assertTrue(b is a) + assert b is a @testing.for_orders("FAK") @testing.for_all_dtypes() @@ -128,101 +342,68 @@ def test_astype_type_f_contiguous_no_copy(self, dtype, order): a = testing.shaped_arange((2, 3, 4), cupy, dtype) a = cupy.asfortranarray(a) b = a.astype(dtype, order=order, copy=False) - self.assertTrue(b is a) + assert b is a @testing.for_all_dtypes_combination(("src_dtype", "dst_dtype")) - @testing.numpy_cupy_array_equal() + @testing.numpy_cupy_equal() def test_astype_strides(self, xp, src_dtype, dst_dtype): - src = xp.empty((1, 2, 3), dtype=src_dtype) - return numpy.array( - astype_without_warning(src, dst_dtype, order="K").strides - ) + src = testing.shaped_arange((1, 2, 3), xp, dtype=src_dtype) + dst = astype_without_warning(src, dst_dtype, order="K") + return get_strides(xp, dst) @testing.for_all_dtypes_combination(("src_dtype", "dst_dtype")) - @testing.numpy_cupy_array_equal() + @testing.numpy_cupy_equal() def test_astype_strides_negative(self, xp, src_dtype, dst_dtype): - src = xp.empty((2, 3), dtype=src_dtype)[::-1, :] - return numpy.array( - astype_without_warning(src, dst_dtype, order="K").strides - ) + src = testing.shaped_arange((2, 3), xp, dtype=src_dtype) + src = src[::-1, :] + dst = astype_without_warning(src, dst_dtype, order="K") + return tuple(abs(x) for x in get_strides(xp, dst)) @testing.for_all_dtypes_combination(("src_dtype", "dst_dtype")) - @testing.numpy_cupy_array_equal() + @testing.numpy_cupy_equal() def test_astype_strides_swapped(self, xp, src_dtype, dst_dtype): - src = xp.swapaxes(xp.empty((2, 3, 4), dtype=src_dtype), 1, 0) - return numpy.array( - astype_without_warning(src, dst_dtype, order="K").strides - ) + src = testing.shaped_arange((2, 3, 4), xp, dtype=src_dtype) + src = xp.swapaxes(src, 1, 0) + dst = astype_without_warning(src, dst_dtype, order="K") + return get_strides(xp, dst) @testing.for_all_dtypes_combination(("src_dtype", "dst_dtype")) - @testing.numpy_cupy_array_equal() + @testing.numpy_cupy_equal() def test_astype_strides_broadcast(self, xp, src_dtype, dst_dtype): src1 = testing.shaped_arange((2, 3, 2), xp, dtype=src_dtype) src2 = testing.shaped_arange((2,), xp, dtype=src_dtype) src, _ = xp.broadcast_arrays(src1, src2) dst = astype_without_warning(src, dst_dtype, order="K") - strides = dst.strides - if xp is numpy: - strides = tuple(x // dst.itemsize for x in strides) - return strides + return get_strides(xp, dst) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") + @pytest.mark.skip("'dpnp_array' object has no attribute 'view' yet") + @testing.numpy_cupy_array_equal() + def test_astype_boolean_view(self, xp): + # See #4354 + a = xp.array([0, 1, 2], dtype=numpy.int8).view(dtype=numpy.bool_) + return a.astype(numpy.int8) + + +@pytest.mark.usefixtures("allow_fall_back_on_numpy") +class TestArrayDiagonal: @testing.for_all_dtypes() @testing.numpy_cupy_array_equal() def test_diagonal1(self, xp, dtype): a = testing.shaped_arange((3, 4, 5), xp, dtype) return a.diagonal(1, 2, 0) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes() @testing.numpy_cupy_array_equal() def test_diagonal2(self, xp, dtype): a = testing.shaped_arange((3, 4, 5), xp, dtype) return a.diagonal(-1, 2, 0) - # @unittest.skipUnless(util.ENABLE_SLICE_COPY, 'Special copy disabled') - @testing.for_orders("CF") - @testing.for_dtypes( - [numpy.int16, numpy.int64, numpy.float16, numpy.float64] - ) - @testing.numpy_cupy_array_equal() - def test_isinstance_numpy_copy(self, xp, dtype, order): - a = numpy.arange(100, dtype=dtype).reshape(10, 10, order=order) - b = xp.empty(a.shape, dtype=dtype, order=order) - b[:] = a - return b - - # @unittest.skipUnless(util.ENABLE_SLICE_COPY, 'Special copy disabled') - def test_isinstance_numpy_copy_wrong_dtype(self): - for xp in (numpy, cupy): - a = numpy.arange(100, dtype=numpy.float64).reshape(10, 10) - b = cupy.empty(a.shape, dtype=numpy.int32) - with pytest.raises(ValueError): - b[:] = a - - # @unittest.skipUnless(util.ENABLE_SLICE_COPY, 'Special copy disabled') - def test_isinstance_numpy_copy_wrong_shape(self): - for xp in (numpy, cupy): - a = numpy.arange(100, dtype=numpy.float64).reshape(10, 10) - b = cupy.empty(100, dtype=a.dtype) - with pytest.raises(ValueError): - b[:] = a - - # @unittest.skipUnless(util.ENABLE_SLICE_COPY, 'Special copy disabled') - @testing.numpy_cupy_array_equal() - def test_isinstance_numpy_copy_not_slice(self, xp): - a = xp.arange(5, dtype=numpy.float64) - a[a < 3] = 0 - return a - @testing.parameterize( {"src_order": "C"}, {"src_order": "F"}, ) -@testing.gpu -class TestNumPyArrayCopyView(unittest.TestCase): - # @unittest.skipUnless(util.ENABLE_SLICE_COPY, 'Special copy disabled') +class TestNumPyArrayCopyView: @testing.for_orders("CF") @testing.for_dtypes( [numpy.int16, numpy.int64, numpy.float16, numpy.float64] @@ -234,3 +415,66 @@ def test_isinstance_numpy_view_copy_f(self, xp, dtype, order): b = xp.empty(a.shape, dtype=dtype, order=order) b[:] = a return b + + +class C_cp(cupy.ndarray): + def __new__(cls, *args, info=None, **kwargs): + obj = super().__new__(cls, *args, **kwargs) + obj.info = info + return obj + + def __array_finalize__(self, obj): + if obj is None: + return + self.info = getattr(obj, "info", None) + + +class C_np(numpy.ndarray): + def __new__(cls, *args, info=None, **kwargs): + obj = super().__new__(cls, *args, **kwargs) + obj.info = info + return obj + + def __array_finalize__(self, obj): + if obj is None: + return + self.info = getattr(obj, "info", None) + + +@pytest.mark.skip("'dpnp_array' object has no attribute 'view' yet") +class TestSubclassArrayView: + def test_view_casting(self): + for xp, C in [(numpy, C_np), (cupy, C_cp)]: + a = xp.arange(5, dtype="i").view("F") + assert type(a) is xp.ndarray + assert a.dtype == xp.float32 + + a = xp.arange(5, dtype="i").view(dtype="F") + assert type(a) is xp.ndarray + assert a.dtype == xp.float32 + + with pytest.raises(TypeError): + xp.arange(5, dtype="i").view("F", dtype="F") + + a = xp.arange(5, dtype="i").view(C) + assert type(a) is C + assert a.dtype == xp.int32 + assert a.info is None + + a = xp.arange(5, dtype="i").view(type=C) + assert type(a) is C + assert a.dtype == xp.int32 + assert a.info is None + + # When an instance of ndarray's subclass is supplied to `dtype`, + # view() interprets it as if it is supplied to `type` + a = xp.arange(5, dtype="i").view(dtype=C) + assert type(a) is C + assert a.dtype == xp.int32 + assert a.info is None + + with pytest.raises(TypeError): + xp.arange(5).view("F", C, type=C) + + with pytest.raises(ValueError): + cupy.arange(5).view(type=numpy.ndarray) diff --git a/tests/third_party/cupy/test_type_routines.py b/tests/third_party/cupy/test_type_routines.py new file mode 100644 index 00000000000..6a274158bcd --- /dev/null +++ b/tests/third_party/cupy/test_type_routines.py @@ -0,0 +1,102 @@ +import unittest + +import numpy +import pytest + +import dpnp as cupy +from tests.third_party.cupy import testing + + +def _generate_type_routines_input(xp, dtype, obj_type): + dtype = numpy.dtype(dtype) + if obj_type == "dtype": + return dtype + if obj_type == "specifier": + return str(dtype) + if obj_type == "scalar": + return dtype.type(3) + if obj_type == "array": + return xp.zeros(3, dtype=dtype) + if obj_type == "primitive": + return type(dtype.type(3).tolist()) + assert False + + +@testing.parameterize( + *testing.product( + { + "obj_type": ["dtype", "specifier", "scalar", "array", "primitive"], + } + ) +) +class TestCanCast(unittest.TestCase): + @testing.for_all_dtypes_combination(names=("from_dtype", "to_dtype")) + @testing.numpy_cupy_equal() + def test_can_cast(self, xp, from_dtype, to_dtype): + if self.obj_type == "scalar": + pytest.skip("to be aligned with NEP-50") + + from_obj = _generate_type_routines_input(xp, from_dtype, self.obj_type) + + ret = xp.can_cast(from_obj, to_dtype) + assert isinstance(ret, bool) + return ret + + +@pytest.mark.skip("dpnp.common_type() is not implemented yet") +class TestCommonType(unittest.TestCase): + @testing.numpy_cupy_equal() + def test_common_type_empty(self, xp): + ret = xp.common_type() + assert type(ret) == type + return ret + + @testing.for_all_dtypes(no_bool=True) + @testing.numpy_cupy_equal() + def test_common_type_single_argument(self, xp, dtype): + array = _generate_type_routines_input(xp, dtype, "array") + ret = xp.common_type(array) + assert type(ret) == type + return ret + + @testing.for_all_dtypes_combination( + names=("dtype1", "dtype2"), no_bool=True + ) + @testing.numpy_cupy_equal() + def test_common_type_two_arguments(self, xp, dtype1, dtype2): + array1 = _generate_type_routines_input(xp, dtype1, "array") + array2 = _generate_type_routines_input(xp, dtype2, "array") + ret = xp.common_type(array1, array2) + assert type(ret) == type + return ret + + @testing.for_all_dtypes() + def test_common_type_bool(self, dtype): + for xp in (numpy, cupy): + array1 = _generate_type_routines_input(xp, dtype, "array") + array2 = _generate_type_routines_input(xp, "bool_", "array") + with pytest.raises(TypeError): + xp.common_type(array1, array2) + + +@testing.parameterize( + *testing.product( + { + "obj_type1": ["dtype", "specifier", "scalar", "array", "primitive"], + "obj_type2": ["dtype", "specifier", "scalar", "array", "primitive"], + } + ) +) +class TestResultType(unittest.TestCase): + @testing.for_all_dtypes_combination(names=("dtype1", "dtype2")) + @testing.numpy_cupy_equal() + def test_result_type(self, xp, dtype1, dtype2): + if "scalar" in {self.obj_type1, self.obj_type2}: + pytest.skip("to be aligned with NEP-50") + + input1 = _generate_type_routines_input(xp, dtype1, self.obj_type1) + + input2 = _generate_type_routines_input(xp, dtype2, self.obj_type2) + ret = xp.result_type(input1, input2) + assert isinstance(ret, numpy.dtype) + return ret