From 304c9868f780e0e5311d59d10873b03e597714be Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Wed, 6 Sep 2023 08:55:34 -0600 Subject: [PATCH] add_real_imag --- dpnp/dparray.pyx | 6 +- dpnp/dpnp_algo/dpnp_elementwise_common.py | 82 +++++++++++ dpnp/dpnp_array.py | 51 ++++++- dpnp/dpnp_iface_mathematical.py | 157 ++++++++++++++++++++-- tests/skipped_tests.tbl | 13 -- tests/skipped_tests_gpu.tbl | 13 -- tests/test_mathematical.py | 24 ++++ tests/test_strides.py | 26 ++++ tests/test_sycl_queue.py | 6 + tests/test_usm_type.py | 6 + 10 files changed, 339 insertions(+), 45 deletions(-) diff --git a/dpnp/dparray.pyx b/dpnp/dparray.pyx index 11dcf0784b7..e8d9dd2e345 100644 --- a/dpnp/dparray.pyx +++ b/dpnp/dparray.pyx @@ -30,7 +30,7 @@ """Module DPArray This module contains Array class represents multi-dimensional array -using USB interface for an Intel GPU device. +using USM interface for an Intel GPU device. """ @@ -51,7 +51,7 @@ from dpnp.dpnp_iface import get_dpnp_descriptor as iface_get_dpnp_descriptor from dpnp.dpnp_iface import prod as iface_prod from dpnp.dpnp_iface import sum as iface_sum -# It's prohibeted to use 'import *' from 'dpnp.dpnp_iface_arraycreation' module here, +# It's prohibited to use 'import *' from 'dpnp.dpnp_iface_arraycreation' module here, # because module has 'array' function, but cython has already imported 'array' by default. # It would cause import collision. Thus instead import each function explicitly. from dpnp.dpnp_iface_arraycreation import ( @@ -196,7 +196,7 @@ cdef class dparray: """Multi-dimensional array using USM interface for an Intel GPU device. This class implements a subset of methods of :class:`numpy.ndarray`. - The difference is that this class allocates the array content useing + The difference is that this class allocates the array content using USM interface on the current GPU device. Args: diff --git a/dpnp/dpnp_algo/dpnp_elementwise_common.py b/dpnp/dpnp_algo/dpnp_elementwise_common.py index ffc1f20065c..926caeec73c 100644 --- a/dpnp/dpnp_algo/dpnp_elementwise_common.py +++ b/dpnp/dpnp_algo/dpnp_elementwise_common.py @@ -63,6 +63,7 @@ "dpnp_floor_divide", "dpnp_greater", "dpnp_greater_equal", + "dpnp_imag", "dpnp_invert", "dpnp_isfinite", "dpnp_isinf", @@ -80,6 +81,7 @@ "dpnp_not_equal", "dpnp_power", "dpnp_proj", + "dpnp_real", "dpnp_remainder", "dpnp_right_shift", "dpnp_round", @@ -1259,6 +1261,46 @@ def dpnp_greater_equal(x1, x2, out=None, order="K"): return dpnp_array._create_from_usm_ndarray(res_usm) +_imag_docstring = """ +imag(x, out=None, order="K") + +Computes imaginary part of each element `x_i` for input array `x`. + +Args: + x (dpnp.ndarray): + Input array, expected to have numeric data type. + out ({None, dpnp.ndarray}, optional): + Output array to populate. + Array have the correct shape and the expected data type. + order ("C","F","A","K", optional): + Memory layout of the newly output array, if parameter `out` is `None`. + Default: "K". +Returns: + dpnp.ndarray: + An array containing the element-wise imaginary component of input. + If the input is a real-valued data type, the returned array has + the same datat type. If the input is a complex floating-point + data type, the returned array has a floating-point data type + with the same floating-point precision as complex input. +""" + + +imag_func = UnaryElementwiseFunc( + "imag", ti._imag_result_type, ti._imag, _imag_docstring +) + + +def dpnp_imag(x, out=None, order="K"): + """Invokes imag() from dpctl.tensor implementation for imag() function.""" + + # dpctl.tensor only works with usm_ndarray + x1_usm = dpnp.get_usm_ndarray(x) + out_usm = None if out is None else dpnp.get_usm_ndarray(out) + + res_usm = imag_func(x1_usm, out=out_usm, order=order) + return dpnp_array._create_from_usm_ndarray(res_usm) + + _invert_docstring = """ invert(x, out=None, order='K') @@ -2021,6 +2063,46 @@ def dpnp_proj(x, out=None, order="K"): return dpnp_array._create_from_usm_ndarray(res_usm) +_real_docstring = """ +real(x, out=None, order="K") + +Computes real part of each element `x_i` for input array `x`. + +Args: + x (dpnp.ndarray): + Input array, expected to have numeric data type. + out ({None, dpnp.ndarray}, optional): + Output array to populate. + Array have the correct shape and the expected data type. + order ("C","F","A","K", optional): + Memory layout of the newly output array, if parameter `out` is `None`. + Default: "K". +Returns: + dpnp.ndarray: + An array containing the element-wise real component of input. + If the input is a real-valued data type, the returned array has + the same datat type. If the input is a complex floating-point + data type, the returned array has a floating-point data type + with the same floating-point precision as complex input. +""" + + +real_func = UnaryElementwiseFunc( + "real", ti._real_result_type, ti._real, _real_docstring +) + + +def dpnp_real(x, out=None, order="K"): + """Invokes real() from dpctl.tensor implementation for real() function.""" + + # dpctl.tensor only works with usm_ndarray + x1_usm = dpnp.get_usm_ndarray(x) + out_usm = None if out is None else dpnp.get_usm_ndarray(out) + + res_usm = real_func(x1_usm, out=out_usm, order=order) + return dpnp_array._create_from_usm_ndarray(res_usm) + + _remainder_docstring_ = """ remainder(x1, x2, out=None, order='K') Calculates the remainder of division for each element `x1_i` of the input array diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 24bc3a03a5e..fae6e98af91 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -819,7 +819,29 @@ def flatten(self, order="C"): return new_arr # 'getfield', - # 'imag', + + @property + def imag(self): + """ + The imaginary part of the array. + + For full documentation refer to :obj:`numpy.ndarray.imag`. + + Examples + -------- + >>> import dpnp as np + >>> x = np.sqrt(np.array([1+0j, 0+1j])) + >>> x.imag + array([ 0. , 0.70710678]) + """ + return dpnp.imag(self) + + @imag.setter + def imag(self, value): + if dpnp.issubsctype(self.dtype, dpnp.complexfloating): + dpnp.copyto(self._array_obj.imag, value) + else: + raise TypeError("dpnp.ndarray does not have imaginary part to set") def item(self, id=None): """ @@ -975,7 +997,30 @@ def put(self, indices, vals, /, *, axis=None, mode="wrap"): return dpnp.put(self, indices, vals, axis=axis, mode=mode) # 'ravel', - # 'real', + + @property + def real(self): + """ + The real part of the array. + + For full documentation refer to :obj:`numpy.ndarray.real`. + + Examples + -------- + >>> import dpnp as np + >>> x = np.sqrt(np.array([1+0j, 0+1j])) + >>> x.real + array([ 1. , 0.70710678]) + """ + if dpnp.issubsctype(self.dtype, dpnp.complexfloating): + return dpnp.real(self) + else: + return self + + @real.setter + def real(self, value): + dpnp.copyto(self._array_obj.real, value) + # 'repeat', def reshape(self, *sh, **kwargs): @@ -1050,7 +1095,7 @@ def shape(self, newshape): """ - dpnp.reshape(self, newshape=newshape) + self._array_obj.shape = newshape @property def size(self): diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 4fad31bf9ee..4138ee4f309 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -56,10 +56,12 @@ dpnp_divide, dpnp_floor, dpnp_floor_divide, + dpnp_imag, dpnp_multiply, dpnp_negative, dpnp_power, dpnp_proj, + dpnp_real, dpnp_remainder, dpnp_round, dpnp_sign, @@ -92,6 +94,7 @@ "fmin", "fmod", "gradient", + "imag", "maximum", "minimum", "mod", @@ -105,6 +108,7 @@ "power", "prod", "proj", + "real", "remainder", "rint", "round", @@ -1162,6 +1166,68 @@ def gradient(x1, *varargs, **kwargs): return call_origin(numpy.gradient, x1, *varargs, **kwargs) +def imag( + x, + /, + out=None, + *, + order="K", + where=True, + dtype=None, + subok=True, + **kwargs, +): + """ + Return the imaginary part of the complex argument. + + For full documentation refer to :obj:`numpy.imag`. + + Returns + ------- + out : dpnp.ndarray + The imaginary component of the complex argument. + + Limitations + ----------- + Parameter `x` is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`. + Parameters `where`, `dtype` and `subok` are supported with their default values. + Keyword argument `kwargs` is currently unsupported. + Otherwise the function will be executed sequentially on CPU. + Input array data types are limited by supported DPNP :ref:`Data types`. + + See Also + -------- + :obj:`dpnp.real` : Return the real part of the complex argument. + :obj:`dpnp.conj` : Return the complex conjugate, element-wise. + :obj:`dpnp.conjugate` : Return the complex conjugate, element-wise. + + Examples + -------- + >>> import dpnp as np + >>> a = np.array([1+2j, 3+4j, 5+6j]) + >>> a.imag + array([2., 4., 6.]) + >>> a.imag = np.array([8, 10, 12]) + >>> a + array([1.+8.j, 3.+10.j, 5.+12.j]) + >>> np.imag(np.array(1 + 1j)) + array(1.0) + + """ + + return check_nd_call_func( + numpy.imag, + dpnp_imag, + x, + out=out, + where=where, + order=order, + dtype=dtype, + subok=subok, + **kwargs, + ) + + def maximum( x1, x2, /, out=None, *, where=True, dtype=None, subok=True, **kwargs ): @@ -1949,7 +2015,7 @@ def proj( ) -def rint( +def real( x, /, out=None, @@ -1961,14 +2027,14 @@ def rint( **kwargs, ): """ - Round elements of the array to the nearest integer. + Return the real part of the complex argument. - For full documentation refer to :obj:`numpy.rint`. + For full documentation refer to :obj:`numpy.real`. Returns ------- out : dpnp.ndarray - The rounded value of elements of the array to the nearest integer. + The real component of the complex argument. Limitations ----------- @@ -1980,23 +2046,30 @@ def rint( See Also -------- - :obj:`dpnp.round` : Evenly round to the given number of decimals. - :obj:`dpnp.ceil` : Compute the ceiling of the input, element-wise. - :obj:`dpnp.floor` : Return the floor of the input, element-wise. - :obj:`dpnp.trunc` : Return the truncated value of the input, element-wise. + :obj:`dpnp.imag` : Return the imaginary part of the complex argument. + :obj:`dpnp.conj` : Return the complex conjugate, element-wise. + :obj:`dpnp.conjugate` : Return the complex conjugate, element-wise. Examples -------- >>> import dpnp as np - >>> a = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) - >>> np.rint(a) - array([-2., -2., -0., 0., 2., 2., 2.]) + >>> a = np.array([1+2j, 3+4j, 5+6j]) + >>> a.real + array([1., 3., 5.]) + >>> a.real = 9 + >>> a + array([9.+2.j, 9.+4.j, 9.+6.j]) + >>> a.real = np.array([9, 8, 7]) + >>> a + array([9.+2.j, 8.+4.j, 7.+6.j]) + >>> np.real(np.array(1 + 1j)) + array(1.) """ return check_nd_call_func( - numpy.rint, - dpnp_round, + numpy.real, + dpnp_real, x, out=out, where=where, @@ -2077,6 +2150,64 @@ def remainder( ) +def rint( + x, + /, + out=None, + *, + order="K", + where=True, + dtype=None, + subok=True, + **kwargs, +): + """ + Round elements of the array to the nearest integer. + + For full documentation refer to :obj:`numpy.rint`. + + Returns + ------- + out : dpnp.ndarray + The rounded value of elements of the array to the nearest integer. + + Limitations + ----------- + Parameter `x` is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`. + Parameters `where`, `dtype` and `subok` are supported with their default values. + Keyword argument `kwargs` is currently unsupported. + Otherwise the function will be executed sequentially on CPU. + Input array data types are limited by supported DPNP :ref:`Data types`. + + See Also + -------- + :obj:`dpnp.round` : Evenly round to the given number of decimals. + :obj:`dpnp.ceil` : Compute the ceiling of the input, element-wise. + :obj:`dpnp.floor` : Return the floor of the input, element-wise. + :obj:`dpnp.trunc` : Return the truncated value of the input, element-wise. + + Examples + -------- + >>> import dpnp as np + >>> a = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) + >>> np.rint(a) + array([-2., -2., -0., 0., 2., 2., 2.]) + + """ + + return check_nd_call_func( + numpy.rint, + dpnp_round, + x, + out=out, + where=where, + order=order, + dtype=dtype, + subok=subok, + **kwargs, + ) + + def round(x, decimals=0, out=None): """ Evenly round to the given number of decimals. diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index 128b7f7f5da..5cfdc6532f7 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -127,21 +127,8 @@ tests/test_umath.py::test_umaths[('spacing', 'f')] tests/test_umath.py::test_umaths[('spacing', 'd')] tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestAngle::test_angle -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_inplace -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_non_contiguous -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_setter -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_setter_non_contiguous -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_setter_raise -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_setter_zero_dim -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_zero_dim -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real_inplace -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real_non_contiguous -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real_setter -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real_setter_non_contiguous -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real_setter_zero_dim -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real_zero_dim tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestScalarConversion::test_scalar_conversion tests/third_party/cupy/core_tests/test_ndarray_conversion.py::TestNdarrayToBytes_param_0_{shape=()}::test_item tests/third_party/cupy/core_tests/test_ndarray_conversion.py::TestNdarrayToBytes_param_1_{shape=(1,)}::test_item diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index 081f0b72350..3ab40ed4d04 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -249,21 +249,8 @@ tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp. tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray(x).astype(dpnp.int8)] tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestAngle::test_angle -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_inplace -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_non_contiguous -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_setter -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_setter_non_contiguous -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_setter_raise -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_setter_zero_dim -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_imag_zero_dim -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real_inplace -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real_non_contiguous -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real_setter -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real_setter_non_contiguous -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real_setter_zero_dim -tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestRealImag::test_real_zero_dim tests/third_party/cupy/core_tests/test_ndarray_complex_ops.py::TestScalarConversion::test_scalar_conversion tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_type diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index adb556fb79c..336aace0388 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -466,6 +466,30 @@ def test_signbit(data, dtype): assert_allclose(result, expected) +@pytest.mark.parametrize( + "data", + [complex(-1, -4), complex(-1, 2), complex(3, -7), complex(4, 12)], + ids=[ + "complex(-1, -4)", + "complex(-1, 2)", + "complex(3, -7)", + "complex(4, 12)", + ], +) +@pytest.mark.parametrize("dtype", get_complex_dtypes()) +def test_real_imag(data, dtype): + np_a = numpy.array(data, dtype=dtype) + dpnp_a = dpnp.array(data, dtype=dtype) + + result = dpnp.real(dpnp_a) + expected = numpy.real(np_a) + assert_allclose(result, expected) + + result = dpnp.imag(dpnp_a) + expected = numpy.imag(np_a) + assert_allclose(result, expected) + + @pytest.mark.parametrize("dtype", get_complex_dtypes()) def test_projection_infinity(dtype): X = [ diff --git a/tests/test_strides.py b/tests/test_strides.py index 1ddfcd14eb5..84e794cd053 100644 --- a/tests/test_strides.py +++ b/tests/test_strides.py @@ -97,6 +97,32 @@ def test_strides_1arg(func_name, dtype, shape): assert_allclose(result, expected, rtol=1e-06) +@pytest.mark.parametrize( + "func_name", + [ + "conjugate", + "imag", + "real", + ], +) +@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) +@pytest.mark.parametrize("shape", [(10,)], ids=["(10,)"]) +def test_strides_1arg_complex(func_name, dtype, shape): + a = numpy.arange(numpy.prod(shape), dtype=dtype).reshape(shape) + b = a[::2] + + dpa = dpnp.reshape(dpnp.arange(numpy.prod(shape), dtype=dtype), shape) + dpb = dpa[::2] + + dpnp_func = _getattr(dpnp, func_name) + result = dpnp_func(dpb) + + numpy_func = _getattr(numpy, func_name) + expected = numpy_func(b) + + assert_allclose(result, expected, rtol=1e-06) + + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True)) @pytest.mark.parametrize("shape", [(10,)], ids=["(10,)"]) def test_strides_erf(dtype, shape): diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index f5b0248f33c..48a562cc798 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -245,12 +245,18 @@ def test_meshgrid(device_x, device_y): pytest.param("fabs", [-1.2, 1.2]), pytest.param("floor", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]), pytest.param("gradient", [1.0, 2.0, 4.0, 7.0, 11.0, 16.0]), + pytest.param( + "imag", [complex(1.0, 2.0), complex(3.0, 4.0), complex(5.0, 6.0)] + ), pytest.param("nancumprod", [1.0, dpnp.nan]), pytest.param("nancumsum", [1.0, dpnp.nan]), pytest.param("nanprod", [1.0, dpnp.nan]), pytest.param("nansum", [1.0, dpnp.nan]), pytest.param("negative", [1.0, 0.0, -1.0]), pytest.param("prod", [1.0, 2.0]), + pytest.param( + "real", [complex(1.0, 2.0), complex(3.0, 4.0), complex(5.0, 6.0)] + ), pytest.param("sign", [-5.0, 0.0, 4.5]), pytest.param("signbit", [-5.0, 0.0, 4.5]), pytest.param( diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index a935c699f77..79125a5376b 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -295,8 +295,14 @@ def test_meshgrid(usm_type_x, usm_type_y): ), pytest.param("cosh", [-5.0, -3.5, 0.0, 3.5, 5.0]), pytest.param("floor", [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]), + pytest.param( + "imag", [complex(1.0, 2.0), complex(3.0, 4.0), complex(5.0, 6.0)] + ), pytest.param("negative", [1.0, 0.0, -1.0]), pytest.param("proj", [complex(1.0, 2.0), complex(dp.inf, -1.0)]), + pytest.param( + "real", [complex(1.0, 2.0), complex(3.0, 4.0), complex(5.0, 6.0)] + ), pytest.param("sign", [-5.0, 0.0, 4.5]), pytest.param("signbit", [-5.0, 0.0, 4.5]), pytest.param(