Skip to content

Commit 26b2eaa

Browse files
Merge pull request #1137 from IntelPython/impl_unstack_moveaxis_swapaxes
add unstack, moveaxis, swapaxes
2 parents 3a22dd9 + 18e680b commit 26b2eaa

File tree

3 files changed

+196
-1
lines changed

3 files changed

+196
-1
lines changed

dpctl/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,14 @@
6969
finfo,
7070
flip,
7171
iinfo,
72+
moveaxis,
7273
permute_dims,
7374
result_type,
7475
roll,
7576
squeeze,
7677
stack,
78+
swapaxes,
79+
unstack,
7780
)
7881
from dpctl.tensor._print import (
7982
get_print_options,
@@ -143,6 +146,9 @@
143146
"complex128",
144147
"iinfo",
145148
"finfo",
149+
"unstack",
150+
"moveaxis",
151+
"swapaxes",
146152
"can_cast",
147153
"result_type",
148154
"meshgrid",

dpctl/tensor/_manipulation_functions.py

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Data Parallel Control (dpctl)
22
#
3-
# Copyright 2020-2022 Intel Corporation
3+
# Copyright 2020-2023 Intel Corporation
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -741,6 +741,119 @@ def finfo(dtype):
741741
return finfo_object(dtype)
742742

743743

744+
def unstack(X, axis=0):
745+
"""unstack(x, axis=0)
746+
747+
Splits an array in a sequence of arrays along the given axis.
748+
749+
Args:
750+
x (usm_ndarray): input array
751+
752+
axis (int, optional): axis along which `x` is unstacked.
753+
If `x` has rank (i.e, number of dimensions) `N`,
754+
a valid `axis` must reside in the half-open interval `[-N, N)`.
755+
Default: `0`.
756+
757+
Returns:
758+
Tuple[usm_ndarray,...]: A tuple of arrays.
759+
760+
Raises:
761+
AxisError: if the `axis` value is invalid.
762+
"""
763+
if not isinstance(X, dpt.usm_ndarray):
764+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
765+
766+
axis = normalize_axis_index(axis, X.ndim)
767+
Y = dpt.moveaxis(X, axis, 0)
768+
769+
return tuple(Y[i] for i in range(Y.shape[0]))
770+
771+
772+
def moveaxis(X, src, dst):
773+
"""moveaxis(x, src, dst)
774+
775+
Moves axes of an array to new positions.
776+
777+
Args:
778+
x (usm_ndarray): input array
779+
780+
src (int or a sequence of int):
781+
Original positions of the axes to move.
782+
These must be unique. If `x` has rank (i.e., number of
783+
dimensions) `N`, a valid `axis` must be in the
784+
half-open interval `[-N, N)`.
785+
786+
dst (int or a sequence of int):
787+
Destination positions for each of the original axes.
788+
These must also be unique. If `x` has rank
789+
(i.e., number of dimensions) `N`, a valid `axis` must be
790+
in the half-open interval `[-N, N)`.
791+
792+
Returns:
793+
usm_narray: Array with moved axes.
794+
The returned array must has the same data type as `x`,
795+
is created on the same device as `x` and has the same
796+
USM allocation type as `x`.
797+
798+
Raises:
799+
AxisError: if `axis` value is invalid.
800+
"""
801+
if not isinstance(X, dpt.usm_ndarray):
802+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
803+
804+
if not isinstance(src, (tuple, list)):
805+
src = (src,)
806+
807+
if not isinstance(dst, (tuple, list)):
808+
dst = (dst,)
809+
810+
src = normalize_axis_tuple(src, X.ndim, "src")
811+
dst = normalize_axis_tuple(dst, X.ndim, "dst")
812+
ind = list(range(0, X.ndim))
813+
for i in range(len(src)):
814+
ind.remove(src[i]) # using the value here which is the same as index
815+
ind.insert(dst[i], src[i])
816+
817+
return dpt.permute_dims(X, tuple(ind))
818+
819+
820+
def swapaxes(X, axis1, axis2):
821+
"""swapaxes(x, axis1, axis2)
822+
823+
Interchanges two axes of an array.
824+
825+
Args:
826+
x (usm_ndarray): input array
827+
828+
axis1 (int): First axis.
829+
If `x` has rank (i.e., number of dimensions) `N`,
830+
a valid `axis` must be in the half-open interval `[-N, N)`.
831+
832+
axis2 (int): Second axis.
833+
If `x` has rank (i.e., number of dimensions) `N`,
834+
a valid `axis` must be in the half-open interval `[-N, N)`.
835+
836+
Returns:
837+
usm_narray: Array with swapped axes.
838+
The returned array must has the same data type as `x`,
839+
is created on the same device as `x` and has the same USM
840+
allocation type as `x`.
841+
842+
Raises:
843+
AxisError: if `axis` value is invalid.
844+
"""
845+
if not isinstance(X, dpt.usm_ndarray):
846+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
847+
848+
axis1 = normalize_axis_index(axis1, X.ndim, "axis1")
849+
axis2 = normalize_axis_index(axis2, X.ndim, "axis2")
850+
851+
ind = list(range(0, X.ndim))
852+
ind[axis1] = axis2
853+
ind[axis2] = axis1
854+
return dpt.permute_dims(X, tuple(ind))
855+
856+
744857
def _supported_dtype(dtypes):
745858
for dtype in dtypes:
746859
if dtype.char not in "?bBhHiIlLqQefdFD":

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,3 +1046,79 @@ def test_result_type():
10461046
X_np = [np.ones((2), dtype=np.int64), np.int32, "float16"]
10471047

10481048
assert dpt.result_type(*X) == np.result_type(*X_np)
1049+
1050+
1051+
def test_swapaxes_1d():
1052+
x = np.array([[1, 2, 3]])
1053+
exp = np.swapaxes(x, 0, 1)
1054+
1055+
y = dpt.asarray([[1, 2, 3]])
1056+
res = dpt.swapaxes(y, 0, 1)
1057+
1058+
assert_array_equal(exp, dpt.asnumpy(res))
1059+
1060+
1061+
def test_swapaxes_2d():
1062+
x = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
1063+
exp = np.swapaxes(x, 0, 2)
1064+
1065+
y = dpt.asarray([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
1066+
res = dpt.swapaxes(y, 0, 2)
1067+
1068+
assert_array_equal(exp, dpt.asnumpy(res))
1069+
1070+
1071+
def test_moveaxis_1axis():
1072+
x = np.arange(60).reshape((3, 4, 5))
1073+
exp = np.moveaxis(x, 0, -1)
1074+
1075+
y = dpt.reshape(dpt.arange(60), (3, 4, 5))
1076+
res = dpt.moveaxis(y, 0, -1)
1077+
1078+
assert_array_equal(exp, dpt.asnumpy(res))
1079+
1080+
1081+
def test_moveaxis_2axes():
1082+
x = np.arange(60).reshape((3, 4, 5))
1083+
exp = np.moveaxis(x, [0, 1], [-1, -2])
1084+
1085+
y = dpt.reshape(dpt.arange(60), (3, 4, 5))
1086+
res = dpt.moveaxis(y, [0, 1], [-1, -2])
1087+
1088+
assert_array_equal(exp, dpt.asnumpy(res))
1089+
1090+
1091+
def test_moveaxis_3axes():
1092+
x = np.arange(60).reshape((3, 4, 5))
1093+
exp = np.moveaxis(x, [0, 1, 2], [-1, -2, -3])
1094+
1095+
y = dpt.reshape(dpt.arange(60), (3, 4, 5))
1096+
res = dpt.moveaxis(y, [0, 1, 2], [-1, -2, -3])
1097+
1098+
assert_array_equal(exp, dpt.asnumpy(res))
1099+
1100+
1101+
def test_unstack_axis0():
1102+
y = dpt.reshape(dpt.arange(6), (2, 3))
1103+
res = dpt.unstack(y)
1104+
1105+
assert_array_equal(dpt.asnumpy(y[0, ...]), dpt.asnumpy(res[0]))
1106+
assert_array_equal(dpt.asnumpy(y[1, ...]), dpt.asnumpy(res[1]))
1107+
1108+
1109+
def test_unstack_axis1():
1110+
y = dpt.reshape(dpt.arange(6), (2, 3))
1111+
res = dpt.unstack(y, 1)
1112+
1113+
assert_array_equal(dpt.asnumpy(y[:, 0, ...]), dpt.asnumpy(res[0]))
1114+
assert_array_equal(dpt.asnumpy(y[:, 1, ...]), dpt.asnumpy(res[1]))
1115+
assert_array_equal(dpt.asnumpy(y[:, 2, ...]), dpt.asnumpy(res[2]))
1116+
1117+
1118+
def test_unstack_axis2():
1119+
y = dpt.reshape(dpt.arange(60), (4, 5, 3))
1120+
res = dpt.unstack(y, 2)
1121+
1122+
assert_array_equal(dpt.asnumpy(y[:, :, 0, ...]), dpt.asnumpy(res[0]))
1123+
assert_array_equal(dpt.asnumpy(y[:, :, 1, ...]), dpt.asnumpy(res[1]))
1124+
assert_array_equal(dpt.asnumpy(y[:, :, 2, ...]), dpt.asnumpy(res[2]))

0 commit comments

Comments
 (0)