Skip to content

Commit 09f7085

Browse files
Add dpnp.roll implementation and update dpnp.rollaxis/dpnp.moveaxis (#1517)
* Add a new dpnp.roll function using dpctl.tensor.roll impl * Update dpnp.rollaxis and add tests for it * Update dpnp.moveaxis * Update implemantations of rolls funcs and fix remarks * Raise ValueError instead of numpy.AxisError * Add support for axis=None and shift is tuple in dpnp.roll
1 parent fa3cd55 commit 09f7085

File tree

6 files changed

+153
-27
lines changed

6 files changed

+153
-27
lines changed

.github/workflows/conda-package.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ env:
3232
third_party/cupy/logic_tests/test_truth.py
3333
third_party/cupy/manipulation_tests/test_basic.py
3434
third_party/cupy/manipulation_tests/test_join.py
35+
third_party/cupy/manipulation_tests/test_rearrange.py
36+
third_party/cupy/manipulation_tests/test_transpose.py
3537
third_party/cupy/math_tests/test_explog.py
3638
third_party/cupy/math_tests/test_misc.py
3739
third_party/cupy/math_tests/test_trigonometric.py

doc/reference/manipulation.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Transpose-like operations
3535
:nosignatures:
3636

3737
dpnp.moveaxis
38+
dpnp.roll
3839
dpnp.rollaxis
3940
dpnp.swapaxes
4041
dpnp.ndarray.T

dpnp/dpnp_iface_manipulation.py

Lines changed: 85 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
import dpctl.tensor as dpt
4444
import numpy
45+
from numpy.core.numeric import normalize_axis_index
4546

4647
import dpnp
4748
from dpnp.dpnp_algo import *
@@ -67,6 +68,7 @@
6768
"repeat",
6869
"reshape",
6970
"result_type",
71+
"roll",
7072
"rollaxis",
7173
"shape",
7274
"squeeze",
@@ -692,7 +694,8 @@ def moveaxis(a, source, destination):
692694
Limitations
693695
-----------
694696
Parameters `a` is supported as either :class:`dpnp.ndarray`
695-
or :class:`dpctl.tensor.usm_ndarray`.
697+
or :class:`dpctl.tensor.usm_ndarray`. Otherwise ``TypeError`` exception
698+
will be raised.
696699
Input array data types are limited by supported DPNP :ref:`Data types`.
697700
Otherwise ``TypeError`` exception will be raised.
698701
@@ -910,18 +913,81 @@ def result_type(*arrays_and_dtypes):
910913
return dpt.result_type(*usm_arrays_and_dtypes)
911914

912915

913-
def rollaxis(x1, axis, start=0):
916+
def roll(x, shift, axis=None):
917+
"""
918+
Roll the elements of an array by a number of positions along a given axis.
919+
920+
Array elements that roll beyond the last position are re-introduced
921+
at the first position. Array elements that roll beyond the first position
922+
are re-introduced at the last position.
923+
924+
For full documentation refer to :obj:`numpy.roll`.
925+
926+
Returns
927+
-------
928+
dpnp.ndarray
929+
An array with the same data type as `x`
930+
and whose elements, relative to `x`, are shifted.
931+
932+
Limitations
933+
-----------
934+
Parameter `x` is supported either as :class:`dpnp.ndarray`
935+
or :class:`dpctl.tensor.usm_ndarray`. Otherwise ``TypeError`` exception
936+
will be raised.
937+
Input array data types are limited by supported DPNP :ref:`Data types`.
938+
939+
940+
See Also
941+
--------
942+
:obj:`dpnp.moveaxis` : Move array axes to new positions.
943+
:obj:`dpnp.rollaxis` : Roll the specified axis backwards
944+
until it lies in a given position.
945+
946+
Examples
947+
--------
948+
>>> import dpnp as np
949+
>>> x1 = np.arange(10)
950+
>>> np.roll(x1, 2)
951+
array([8, 9, 0, 1, 2, 3, 4, 5, 6, 7])
952+
953+
>>> np.roll(x1, -2)
954+
array([2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
955+
956+
>>> x2 = np.reshape(x1, (2, 5))
957+
>>> np.roll(x2, 1, axis=0)
958+
array([[5, 6, 7, 8, 9],
959+
[0, 1, 2, 3, 4]])
960+
961+
>>> np.roll(x2, (2, 1), axis=(1, 0))
962+
array([[8, 9, 5, 6, 7],
963+
[3, 4, 0, 1, 2]])
964+
965+
"""
966+
if axis is None:
967+
return roll(x.reshape(-1), shift, 0).reshape(x.shape)
968+
dpt_array = dpnp.get_usm_ndarray(x)
969+
return dpnp_array._create_from_usm_ndarray(
970+
dpt.roll(dpt_array, shift=shift, axis=axis)
971+
)
972+
973+
974+
def rollaxis(x, axis, start=0):
914975
"""
915976
Roll the specified axis backwards, until it lies in a given position.
916977
917978
For full documentation refer to :obj:`numpy.rollaxis`.
918979
980+
Returns
981+
-------
982+
dpnp.ndarray
983+
An array with the same data type as `x` where the specified axis
984+
has been repositioned to the desired position.
985+
919986
Limitations
920987
-----------
921-
Input array is supported as :obj:`dpnp.ndarray`.
922-
Parameter ``axis`` is supported as integer only.
923-
Parameter ``start`` is limited by ``-a.ndim <= start <= a.ndim``.
924-
Otherwise the function will be executed sequentially on CPU.
988+
Parameter `x` is supported either as :class:`dpnp.ndarray`
989+
or :class:`dpctl.tensor.usm_ndarray`. Otherwise ``TypeError`` exception
990+
will be raised.
925991
Input array data types are limited by supported DPNP :ref:`Data types`.
926992
927993
See Also
@@ -943,19 +1009,19 @@ def rollaxis(x1, axis, start=0):
9431009
9441010
"""
9451011

946-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
947-
if x1_desc:
948-
if not isinstance(axis, int):
949-
pass
950-
elif start < -x1_desc.ndim or start > x1_desc.ndim:
951-
pass
952-
else:
953-
start_norm = start + x1_desc.ndim if start < 0 else start
954-
destination = start_norm - 1 if start_norm > axis else start_norm
955-
956-
return dpnp.moveaxis(x1_desc.get_pyobj(), axis, destination)
957-
958-
return call_origin(numpy.rollaxis, x1, axis, start)
1012+
n = x.ndim
1013+
axis = normalize_axis_index(axis, n)
1014+
if start < 0:
1015+
start += n
1016+
msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
1017+
if not (0 <= start < n + 1):
1018+
raise ValueError(msg % ("start", -n, "start", n + 1, start))
1019+
if axis < start:
1020+
start -= 1
1021+
if axis == start:
1022+
return x
1023+
dpt_array = dpnp.get_usm_ndarray(x)
1024+
return dpnp.moveaxis(dpt_array, source=axis, destination=start)
9591025

9601026

9611027
def shape(a):

tests/test_arraymanipulation.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,3 +605,68 @@ def test_2D_array2(self):
605605
def test_generator(self):
606606
with assert_warns(FutureWarning):
607607
dpnp.vstack((numpy.arange(3) for _ in range(2)))
608+
609+
610+
class TestRollaxis:
611+
data = [
612+
(0, 0),
613+
(0, 1),
614+
(0, 2),
615+
(0, 3),
616+
(0, 4),
617+
(1, 0),
618+
(1, 1),
619+
(1, 2),
620+
(1, 3),
621+
(1, 4),
622+
(2, 0),
623+
(2, 1),
624+
(2, 2),
625+
(2, 3),
626+
(2, 4),
627+
(3, 0),
628+
(3, 1),
629+
(3, 2),
630+
(3, 3),
631+
(3, 4),
632+
]
633+
634+
@pytest.mark.parametrize(
635+
("axis", "start"),
636+
[
637+
(-5, 0),
638+
(0, -5),
639+
(4, 0),
640+
(0, 5),
641+
],
642+
)
643+
def test_exceptions(self, axis, start):
644+
a = dpnp.arange(1 * 2 * 3 * 4).reshape(1, 2, 3, 4)
645+
assert_raises(ValueError, dpnp.rollaxis, a, axis, start)
646+
647+
def test_results(self):
648+
np_a = numpy.arange(1 * 2 * 3 * 4).reshape(1, 2, 3, 4)
649+
dp_a = dpnp.array(np_a)
650+
for i, j in self.data:
651+
# positive axis, positive start
652+
res = dpnp.rollaxis(dp_a, axis=i, start=j)
653+
exp = numpy.rollaxis(np_a, axis=i, start=j)
654+
assert res.shape == exp.shape
655+
656+
# negative axis, positive start
657+
ip = i + 1
658+
res = dpnp.rollaxis(dp_a, axis=-ip, start=j)
659+
exp = numpy.rollaxis(np_a, axis=-ip, start=j)
660+
assert res.shape == exp.shape
661+
662+
# positive axis, negative start
663+
jp = j + 1 if j < 4 else j
664+
res = dpnp.rollaxis(dp_a, axis=i, start=-jp)
665+
exp = numpy.rollaxis(np_a, axis=i, start=-jp)
666+
assert res.shape == exp.shape
667+
668+
# negative axis, negative start
669+
ip = i + 1
670+
jp = j + 1 if j < 4 else j
671+
res = dpnp.rollaxis(dp_a, axis=-ip, start=-jp)
672+
exp = numpy.rollaxis(np_a, axis=-ip, start=-jp)

tests/third_party/cupy/manipulation_tests/test_rearrange.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
{"shape": (5, 2), "shift": (2, 1, 3), "axis": 0},
2626
{"shape": (5, 2), "shift": (2, 1, 3), "axis": None},
2727
)
28-
@pytest.mark.skip("`roll` isn't supported yet")
2928
class TestRoll(unittest.TestCase):
3029
@testing.for_all_dtypes()
3130
@testing.numpy_cupy_array_equal()
@@ -38,12 +37,9 @@ def test_roll(self, xp, dtype):
3837
def test_roll_cupy_shift(self, xp, dtype):
3938
x = testing.shaped_arange(self.shape, xp, dtype)
4039
shift = self.shift
41-
if xp is cupy:
42-
shift = cupy.array(shift)
4340
return xp.roll(x, shift, axis=self.axis)
4441

4542

46-
@pytest.mark.skip("`roll` isn't supported yet")
4743
class TestRollTypeError(unittest.TestCase):
4844
def test_roll_invalid_shift(self):
4945
for xp in (numpy, cupy):
@@ -66,7 +62,6 @@ def test_roll_invalid_axis_type(self):
6662
{"shape": (5, 2), "shift": 1, "axis": -3},
6763
{"shape": (5, 2), "shift": 1, "axis": (1, -3)},
6864
)
69-
@pytest.mark.skip("`roll` isn't supported yet")
7065
class TestRollValueError(unittest.TestCase):
7166
def test_roll_invalid(self):
7267
for xp in (numpy, cupy):
@@ -78,8 +73,6 @@ def test_roll_invalid_cupy_shift(self):
7873
for xp in (numpy, cupy):
7974
x = testing.shaped_arange(self.shape, xp)
8075
shift = self.shift
81-
if xp is cupy:
82-
shift = cupy.array(shift)
8376
with pytest.raises(ValueError):
8477
xp.roll(x, shift, axis=self.axis)
8578

tests/third_party/cupy/manipulation_tests/test_transpose.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ def test_rollaxis(self, xp):
103103
a = testing.shaped_arange((2, 3, 4), xp)
104104
return xp.rollaxis(a, 2)
105105

106-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
107106
def test_rollaxis_failure(self):
108107
for xp in (numpy, cupy):
109108
a = testing.shaped_arange((2, 3, 4), xp)

0 commit comments

Comments
 (0)