From 22b1c730d980e9f062889a6b01ca0c930cea8606 Mon Sep 17 00:00:00 2001 From: vtavana <120411540+vtavana@users.noreply.github.com> Date: Sun, 15 Sep 2024 03:56:46 -0500 Subject: [PATCH] Implement `dpnp.rot90` and `dpnp.resize` (#2030) * add rot90 and resize * fix docstring * address comments --------- Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com> --- dpnp/dpnp_iface_manipulation.py | 216 +++++++++++++++++- tests/test_manipulation.py | 118 +++++++++- tests/test_sycl_queue.py | 15 ++ tests/test_usm_type.py | 10 + .../manipulation_tests/test_add_remove.py | 6 +- .../cupy/manipulation_tests/test_rearrange.py | 1 - 6 files changed, 350 insertions(+), 16 deletions(-) diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index ee3e987fb6a..6346fe24326 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -39,6 +39,7 @@ import math +import operator import dpctl.tensor as dpt import numpy @@ -77,9 +78,11 @@ "ravel", "repeat", "reshape", + "resize", "result_type", "roll", "rollaxis", + "rot90", "row_stack", "shape", "size", @@ -1522,6 +1525,7 @@ def fliplr(m): -------- :obj:`dpnp.flipud` : Flip an array vertically (axis=0). :obj:`dpnp.flip` : Flip array in one or more dimensions. + :obj:`dpnp.rot90` : Rotate array counterclockwise. Examples -------- @@ -1572,6 +1576,7 @@ def flipud(m): -------- :obj:`dpnp.fliplr` : Flip array in the left/right direction. :obj:`dpnp.flip` : Flip array in one or more dimensions. + :obj:`dpnp.rot90` : Rotate array counterclockwise. Examples -------- @@ -1985,7 +1990,8 @@ def reshape(a, /, newshape, order="C", copy=None): If ``False``, the result array can never be a copy and a ValueError exception will be raised in case the copy is necessary. If ``None``, the result array will reuse existing memory buffer of `a` - if possible and copy otherwise. Default: None. + if possible and copy otherwise. + Default: ``None``. Returns ------- @@ -2004,14 +2010,14 @@ def reshape(a, /, newshape, order="C", copy=None): Examples -------- - >>> import dpnp as dp - >>> a = dp.array([[1, 2, 3], [4, 5, 6]]) - >>> dp.reshape(a, 6) + >>> import dpnp as np + >>> a = np.array([[1, 2, 3], [4, 5, 6]]) + >>> np.reshape(a, 6) array([1, 2, 3, 4, 5, 6]) - >>> dp.reshape(a, 6, order='F') + >>> np.reshape(a, 6, order='F') array([1, 4, 2, 5, 3, 6]) - >>> dp.reshape(a, (3, -1)) # the unspecified value is inferred to be 2 + >>> np.reshape(a, (3, -1)) # the unspecified value is inferred to be 2 array([[1, 2], [3, 4], [5, 6]]) @@ -2031,6 +2037,91 @@ def reshape(a, /, newshape, order="C", copy=None): return dpnp_array._create_from_usm_ndarray(usm_res) +def resize(a, new_shape): + """ + Return a new array with the specified shape. + + If the new array is larger than the original array, then the new array is + filled with repeated copies of `a`. Note that this behavior is different + from ``a.resize(new_shape)`` which fills with zeros instead of repeated + copies of `a`. + + For full documentation refer to :obj:`numpy.resize`. + + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + Array to be resized. + new_shape : {int, tuple or list of ints} + Shape of resized array. + + Returns + ------- + out : dpnp.ndarray + The new array is formed from the data in the old array, repeated + if necessary to fill out the required number of elements. The + data are repeated iterating over the array in C-order. + + See Also + -------- + :obj:`dpnp.ndarray.resize` : Resize an array in-place. + :obj:`dpnp.reshape` : Reshape an array without changing the total size. + :obj:`dpnp.pad` : Enlarge and pad an array. + :obj:`dpnp.repeat` : Repeat elements of an array. + + Notes + ----- + When the total size of the array does not change :obj:`dpnp.reshape` should + be used. In most other cases either indexing (to reduce the size) or + padding (to increase the size) may be a more appropriate solution. + + Warning: This functionality does **not** consider axes separately, + i.e. it does not apply interpolation/extrapolation. + It fills the return array with the required number of elements, iterating + over `a` in C-order, disregarding axes (and cycling back from the start if + the new shape is larger). This functionality is therefore not suitable to + resize images, or data where each axis represents a separate and distinct + entity. + + Examples + -------- + >>> import dpnp as np + >>> a = np.array([[0, 1], [2, 3]]) + >>> np.resize(a, (2, 3)) + array([[0, 1, 2], + [3, 0, 1]]) + >>> np.resize(a, (1, 4)) + array([[0, 1, 2, 3]]) + >>> np.resize(a, (2, 4)) + array([[0, 1, 2, 3], + [0, 1, 2, 3]]) + + """ + + dpnp.check_supported_arrays_type(a) + if a.ndim == 0: + return dpnp.full_like(a, a, shape=new_shape) + + if isinstance(new_shape, (int, numpy.integer)): + new_shape = (new_shape,) + + new_size = 1 + for dim_length in new_shape: + if dim_length < 0: + raise ValueError("all elements of `new_shape` must be non-negative") + new_size *= dim_length + + a_size = a.size + if a_size == 0 or new_size == 0: + # First case must zero fill. The second would have repeats == 0. + return dpnp.zeros_like(a, shape=new_shape) + + repeats = -(-new_size // a_size) # ceil division + a = dpnp.concatenate((dpnp.ravel(a),) * repeats)[:new_size] + + return a.reshape(new_shape) + + def result_type(*arrays_and_dtypes): """ result_type(*arrays_and_dtypes) @@ -2052,16 +2143,16 @@ def result_type(*arrays_and_dtypes): Examples -------- - >>> import dpnp as dp - >>> a = dp.arange(3, dtype=dp.int64) - >>> b = dp.arange(7, dtype=dp.int32) - >>> dp.result_type(a, b) + >>> import dpnp as np + >>> a = np.arange(3, dtype=np.int64) + >>> b = np.arange(7, dtype=np.int32) + >>> np.result_type(a, b) dtype('int64') - >>> dp.result_type(dp.int64, dp.complex128) + >>> np.result_type(np.int64, np.complex128) dtype('complex128') - >>> dp.result_type(dp.ones(10, dtype=dp.float32), dp.float64) + >>> np.result_type(np.ones(10, dtype=np.float32), np.float64) dtype('float64') """ @@ -2200,6 +2291,107 @@ def rollaxis(x, axis, start=0): return dpnp.moveaxis(usm_array, source=axis, destination=start) +def rot90(m, k=1, axes=(0, 1)): + """ + Rotate an array by 90 degrees in the plane specified by axes. + + Rotation direction is from the first towards the second axis. + This means for a 2D array with the default `k` and `axes`, the + rotation will be counterclockwise. + + For full documentation refer to :obj:`numpy.rot90`. + + Parameters + ---------- + m : {dpnp.ndarray, usm_ndarray} + Array of two or more dimensions. + k : integer, optional + Number of times the array is rotated by 90 degrees. + Default: ``1``. + axes : (2,) array_like of ints, optional + The array is rotated in the plane defined by the axes. + Axes must be different. + Default: ``(0, 1)``. + + Returns + ------- + out : dpnp.ndarray + A rotated view of `m`. + + See Also + -------- + :obj:`dpnp.flip` : Reverse the order of elements in an array along + the given axis. + :obj:`dpnp.fliplr` : Flip an array horizontally. + :obj:`dpnp.flipud` : Flip an array vertically. + + Notes + ----- + ``rot90(m, k=1, axes=(1,0))`` is the reverse of + ``rot90(m, k=1, axes=(0,1))``. + + ``rot90(m, k=1, axes=(1,0))`` is equivalent to + ``rot90(m, k=-1, axes=(0,1))``. + + Examples + -------- + >>> import dpnp as np + >>> m = np.array([[1, 2], [3, 4]]) + >>> m + array([[1, 2], + [3, 4]]) + >>> np.rot90(m) + array([[2, 4], + [1, 3]]) + >>> np.rot90(m, 2) + array([[4, 3], + [2, 1]]) + >>> m = np.arange(8).reshape((2, 2, 2)) + >>> np.rot90(m, 1, (1, 2)) + array([[[1, 3], + [0, 2]], + [[5, 7], + [4, 6]]]) + + """ + + dpnp.check_supported_arrays_type(m) + k = operator.index(k) + + m_ndim = m.ndim + if m_ndim < 2: + raise ValueError("Input must be at least 2-d.") + + if len(axes) != 2: + raise ValueError("len(axes) must be 2.") + + if axes[0] == axes[1] or abs(axes[0] - axes[1]) == m_ndim: + raise ValueError("Axes must be different.") + + if not (-m_ndim <= axes[0] < m_ndim and -m_ndim <= axes[1] < m_ndim): + raise ValueError( + f"Axes={axes} out of range for array of ndim={m_ndim}." + ) + + k %= 4 + if k == 0: + return m[:] + if k == 2: + return dpnp.flip(dpnp.flip(m, axes[0]), axes[1]) + + axes_list = list(range(0, m_ndim)) + (axes_list[axes[0]], axes_list[axes[1]]) = ( + axes_list[axes[1]], + axes_list[axes[0]], + ) + + if k == 1: + return dpnp.transpose(dpnp.flip(m, axes[1]), axes_list) + + # k == 3 + return dpnp.flip(dpnp.transpose(m, axes_list), axes[1]) + + def shape(a): """ Return the shape of an array. diff --git a/tests/test_manipulation.py b/tests/test_manipulation.py index b9afeef7f29..b2122255d50 100644 --- a/tests/test_manipulation.py +++ b/tests/test_manipulation.py @@ -2,7 +2,7 @@ import numpy import pytest from dpctl.tensor._numpy_helper import AxisError -from numpy.testing import assert_array_equal, assert_raises +from numpy.testing import assert_array_equal, assert_equal, assert_raises import dpnp from tests.third_party.cupy import testing @@ -665,6 +665,122 @@ def test_minimum_signed_integers(self, data, dtype): assert_array_equal(result, expected) +class TestResize: + @pytest.mark.parametrize( + "data, shape", + [ + pytest.param([[1, 2], [3, 4]], (2, 4)), + pytest.param([[1, 2], [3, 4], [1, 2], [3, 4]], (4, 2)), + pytest.param([[1, 2, 3], [4, 1, 2], [3, 4, 1], [2, 3, 4]], (4, 3)), + ], + ) + def test_copies(self, data, shape): + a = numpy.array(data) + ia = dpnp.array(a) + assert_equal(dpnp.resize(ia, shape), numpy.resize(a, shape)) + + @pytest.mark.parametrize("newshape", [(2, 4), [2, 4], (10,), 10]) + def test_newshape_type(self, newshape): + a = numpy.array([[1, 2], [3, 4]]) + ia = dpnp.array(a) + assert_equal(dpnp.resize(ia, newshape), numpy.resize(a, newshape)) + + @pytest.mark.parametrize( + "data, shape", + [ + pytest.param([1, 2, 3], (2, 4)), + pytest.param([[1, 2], [3, 1], [2, 3], [1, 2]], (4, 2)), + pytest.param([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]], (4, 3)), + ], + ) + def test_repeats(self, data, shape): + a = numpy.array(data) + ia = dpnp.array(a) + assert_equal(dpnp.resize(ia, shape), numpy.resize(a, shape)) + + def test_zeroresize(self): + a = numpy.array([[1, 2], [3, 4]]) + ia = dpnp.array(a) + assert_array_equal(dpnp.resize(ia, (0,)), numpy.resize(a, (0,))) + assert_equal(a.dtype, ia.dtype) + + assert_equal(dpnp.resize(ia, (0, 2)), numpy.resize(a, (0, 2))) + assert_equal(dpnp.resize(ia, (2, 0)), numpy.resize(a, (2, 0))) + + def test_reshape_from_zero(self): + a = numpy.zeros(0, dtype=numpy.float32) + ia = dpnp.array(a) + assert_array_equal(dpnp.resize(ia, (2, 1)), numpy.resize(a, (2, 1))) + assert_equal(a.dtype, ia.dtype) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_negative_resize(self, xp): + a = xp.arange(0, 10, dtype=xp.float32) + new_shape = (-10, -1) + with pytest.raises(ValueError, match=r"negative"): + xp.resize(a, new_shape=new_shape) + + +class TestRot90: + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_error(self, xp): + assert_raises(ValueError, xp.rot90, xp.ones(4)) + assert_raises(ValueError, xp.rot90, xp.ones((2, 2, 2)), axes=(0, 1, 2)) + assert_raises(ValueError, xp.rot90, xp.ones((2, 2)), axes=(0, 2)) + assert_raises(ValueError, xp.rot90, xp.ones((2, 2)), axes=(1, 1)) + assert_raises(ValueError, xp.rot90, xp.ones((2, 2, 2)), axes=(-2, 1)) + + def test_error_float_k(self): + assert_raises(TypeError, dpnp.rot90, dpnp.ones((2, 2)), k=2.5) + + def test_basic(self): + a = numpy.array([[0, 1, 2], [3, 4, 5]]) + ia = dpnp.array(a) + + for k in range(-3, 13, 4): + assert_equal(dpnp.rot90(ia, k=k), numpy.rot90(a, k=k)) + for k in range(-2, 13, 4): + assert_equal(dpnp.rot90(ia, k=k), numpy.rot90(a, k=k)) + for k in range(-1, 13, 4): + assert_equal(dpnp.rot90(ia, k=k), numpy.rot90(a, k=k)) + for k in range(0, 13, 4): + assert_equal(dpnp.rot90(ia, k=k), numpy.rot90(a, k=k)) + + assert_equal(dpnp.rot90(dpnp.rot90(ia, axes=(0, 1)), axes=(1, 0)), ia) + assert_equal( + dpnp.rot90(ia, k=1, axes=(1, 0)), dpnp.rot90(ia, k=-1, axes=(0, 1)) + ) + + def test_axes(self): + a = numpy.ones((50, 40, 3)) + ia = dpnp.array(a) + assert_equal(dpnp.rot90(ia), numpy.rot90(a)) + assert_equal(dpnp.rot90(ia, axes=(0, 2)), dpnp.rot90(ia, axes=(0, -1))) + assert_equal(dpnp.rot90(ia, axes=(1, 2)), dpnp.rot90(ia, axes=(-2, -1))) + + @pytest.mark.parametrize( + "axes", [(1, 2), [1, 2], numpy.array([1, 2]), dpnp.array([1, 2])] + ) + def test_axes_type(self, axes): + a = numpy.ones((50, 40, 3)) + ia = dpnp.array(a) + assert_equal(dpnp.rot90(ia, axes=axes), numpy.rot90(a, axes=axes)) + + def test_rotation_axes(self): + a = numpy.arange(8).reshape((2, 2, 2)) + ia = dpnp.array(a) + + assert_equal(dpnp.rot90(ia, axes=(0, 1)), numpy.rot90(a, axes=(0, 1))) + assert_equal(dpnp.rot90(ia, axes=(1, 0)), numpy.rot90(a, axes=(1, 0))) + assert_equal(dpnp.rot90(ia, axes=(1, 2)), numpy.rot90(a, axes=(1, 2))) + + for k in range(1, 5): + assert_equal( + dpnp.rot90(ia, k=k, axes=(2, 0)), + numpy.rot90(a, k=k, axes=(2, 0)), + ) + + class TestTranspose: @pytest.mark.parametrize("axes", [(0, 1), (1, 0), [0, 1]]) def test_2d_with_axes(self, axes): diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index a41b7d52e77..1b44b207352 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -495,6 +495,7 @@ def test_meshgrid(device): ), pytest.param("real_if_close", [2.1 + 4e-15j, 5.2 + 3e-16j]), pytest.param("reciprocal", [1.0, 2.0, 4.0, 7.0]), + pytest.param("rot90", [[1, 2], [3, 4]]), pytest.param("sign", [-5.0, 0.0, 4.5]), pytest.param("signbit", [-5.0, 0.0, 4.5]), pytest.param( @@ -1284,6 +1285,20 @@ def test_out_multi_dot(device): assert_sycl_queue_equal(result.sycl_queue, exec_q) +@pytest.mark.parametrize( + "device", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +def test_resize(device): + dpnp_data = dpnp.arange(10, device=device) + result = dpnp.resize(dpnp_data, (2, 5)) + + expected_queue = dpnp_data.sycl_queue + result_queue = result.sycl_queue + assert_sycl_queue_equal(result_queue, expected_queue) + + class TestFft: @pytest.mark.parametrize( "func", ["fft", "ifft", "rfft", "irfft", "hfft", "ihfft"] diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index b76f9f42b33..5e0e50738b8 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -622,6 +622,7 @@ def test_norm(usm_type, ord, axis): pytest.param("real_if_close", [2.1 + 4e-15j, 5.2 + 3e-16j]), pytest.param("reciprocal", [1.0, 2.0, 4.0, 7.0]), pytest.param("reduce_hypot", [1.0, 2.0, 4.0, 7.0]), + pytest.param("rot90", [[1, 2], [3, 4]]), pytest.param("rsqrt", [1, 8, 27]), pytest.param("sign", [-5.0, 0.0, 4.5]), pytest.param("signbit", [-5.0, 0.0, 4.5]), @@ -1012,6 +1013,15 @@ def test_eigenvalue(func, shape, usm_type): assert a.usm_type == dp_val.usm_type +@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) +def test_resize(usm_type): + dpnp_data = dp.arange(10, usm_type=usm_type) + result = dp.resize(dpnp_data, (2, 5)) + + assert dpnp_data.usm_type == usm_type + assert result.usm_type == usm_type + + class TestFft: @pytest.mark.parametrize( "func", ["fft", "ifft", "rfft", "irfft", "hfft", "ihfft"] diff --git a/tests/third_party/cupy/manipulation_tests/test_add_remove.py b/tests/third_party/cupy/manipulation_tests/test_add_remove.py index c721f9ab31a..264e7208cac 100644 --- a/tests/third_party/cupy/manipulation_tests/test_add_remove.py +++ b/tests/third_party/cupy/manipulation_tests/test_add_remove.py @@ -123,7 +123,6 @@ def test_empty(self, xp): return xp.append(xp.array([]), xp.arange(10)) -@pytest.mark.skip("resize() is not implemented yet") class TestResize(unittest.TestCase): @testing.numpy_cupy_array_equal() def test(self, xp): @@ -137,14 +136,17 @@ def test_remainder(self, xp): def test_shape_int(self, xp): return xp.resize(xp.arange(10), 15) + @pytest.mark.skip("scalar is not supported.") @testing.numpy_cupy_array_equal() def test_scalar(self, xp): return xp.resize(2, (10, 10)) + @pytest.mark.skip("scalar is not supported.") @testing.numpy_cupy_array_equal() def test_scalar_shape_int(self, xp): return xp.resize(2, 10) + @pytest.mark.skip("scalar is not supported.") @testing.numpy_cupy_array_equal() def test_typed_scalar(self, xp): return xp.resize(xp.float32(10.0), (10, 10)) @@ -153,7 +155,7 @@ def test_typed_scalar(self, xp): def test_zerodim(self, xp): return xp.resize(xp.array(0), (10, 10)) - @testing.numpy_cupy_array_equal() + @testing.numpy_cupy_array_equal(type_check=has_support_aspect64()) def test_empty(self, xp): return xp.resize(xp.array([]), (10, 10)) diff --git a/tests/third_party/cupy/manipulation_tests/test_rearrange.py b/tests/third_party/cupy/manipulation_tests/test_rearrange.py index f97f611ff25..79a03520637 100644 --- a/tests/third_party/cupy/manipulation_tests/test_rearrange.py +++ b/tests/third_party/cupy/manipulation_tests/test_rearrange.py @@ -202,7 +202,6 @@ def test_flip_invalid_negative_axis(self, dtype): xp.flip(x, -3) -@pytest.mark.skip("`rot90` isn't supported yet") class TestRot90(unittest.TestCase): @testing.for_all_dtypes() @testing.numpy_cupy_array_equal()