Skip to content

Added dpctl.tensor.stack feature and tests #872

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Wrote manual page about working with `dpctl.SyclQueue` [#829](https://github.com/IntelPython/dpctl/pull/829).
* Added cmake scripts to dpctl package layout and a way to query the location [#853](https://github.com/IntelPython/dpctl/pull/853).
* Implemented `dpctl.tensor.concat` function from array-API [#867](https://github.com/IntelPython/dpctl/867).
* Implemented `dpctl.tensor.stack` function from array-API [#872](https://github.com/IntelPython/dpctl/872).


### Changed
Expand Down
2 changes: 2 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
permute_dims,
roll,
squeeze,
stack,
)
from dpctl.tensor._reshape import reshape
from dpctl.tensor._usmarray import usm_ndarray
Expand All @@ -68,6 +69,7 @@
"reshape",
"roll",
"concat",
"stack",
"broadcast_arrays",
"broadcast_to",
"expand_dims",
Expand Down
78 changes: 63 additions & 15 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,7 @@ def roll(X, shift, axes=None):
return res


def concat(arrays, axis=0):
"""
concat(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray

Joins a sequence of arrays along an existing axis.
"""
def _arrays_validation(arrays):
n = len(arrays)
if n == 0:
raise TypeError("Missing 1 required positional argument: 'arrays'")
Expand Down Expand Up @@ -324,11 +319,23 @@ def concat(arrays, axis=0):
for i in range(1, n):
if X0.ndim != arrays[i].ndim:
raise ValueError(
"All the input arrays must have same number of "
"dimensions, but the array at index 0 has "
f"{X0.ndim} dimension(s) and the array at index "
f"{i} has {arrays[i].ndim} dimension(s)"
"All the input arrays must have same number of dimensions, "
f"but the array at index 0 has {X0.ndim} dimension(s) and the "
f"array at index {i} has {arrays[i].ndim} dimension(s)"
)
return res_dtype, res_usm_type, exec_q


def concat(arrays, axis=0):
"""
concat(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray

Joins a sequence of arrays along an existing axis.
"""
res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)

n = len(arrays)
X0 = arrays[0]

axis = normalize_axis_index(axis, X0.ndim)
X0_shape = X0.shape
Expand All @@ -337,11 +344,10 @@ def concat(arrays, axis=0):
for j in range(X0.ndim):
if X0_shape[j] != Xi_shape[j] and j != axis:
raise ValueError(
"All the input array dimensions for the "
"concatenation axis must match exactly, but "
f"along dimension {j}, the array at index 0 "
f"has size {X0_shape[j]} and the array at "
f"index {i} has size {Xi_shape[j]}"
"All the input array dimensions for the concatenation "
f"axis must match exactly, but along dimension {j}, the "
f"array at index 0 has size {X0_shape[j]} and the array "
f"at index {i} has size {Xi_shape[j]}"
)

res_shape_axis = 0
Expand Down Expand Up @@ -373,3 +379,45 @@ def concat(arrays, axis=0):
dpctl.SyclEvent.wait_for(hev_list)

return res


def stack(arrays, axis=0):
"""
stack(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray

Joins a sequence of arrays along a new axis.
"""
res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)

n = len(arrays)
X0 = arrays[0]
res_ndim = X0.ndim + 1
axis = normalize_axis_index(axis, res_ndim)
X0_shape = X0.shape

for i in range(1, n):
if X0_shape != arrays[i].shape:
raise ValueError("All input arrays must have the same shape")

res_shape = tuple(
X0_shape[i - 1 * (i >= axis)] if i != axis else n
for i in range(res_ndim)
)

res = dpt.empty(
res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
)

hev_list = []
for i in range(n):
c_shapes_copy = tuple(
i if j == axis else np.s_[:] for j in range(res_ndim)
)
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=arrays[i], dst=res[c_shapes_copy], sycl_queue=exec_q
)
hev_list.append(hev)

dpctl.SyclEvent.wait_for(hev_list)

return res
110 changes: 110 additions & 0 deletions dpctl/tests/test_usm_ndarray_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,3 +889,113 @@ def test_concat_3arrays(data):
R = dpt.concat([X, Y, Z], axis=axis)

assert_array_equal(Rnp, dpt.asnumpy(R))


def test_stack_incorrect_shape():
try:
q = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
pytest.skip("Queue could not be created")

X = dpt.ones((1,), sycl_queue=q)
Y = dpt.ones((2,), sycl_queue=q)

pytest.raises(ValueError, dpt.stack, [X, Y], 0)


@pytest.mark.parametrize(
"data",
[
[(6,), 0],
[(2, 3), 1],
[(3, 2), -1],
[(1, 6), 2],
[(2, 1, 3), 2],
],
)
def test_stack_1array(data):
try:
q = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
pytest.skip("Queue could not be created")

shape, axis = data

Xnp = np.arange(6).reshape(shape)
X = dpt.asarray(Xnp, sycl_queue=q)

Ynp = np.stack([Xnp], axis=axis)
Y = dpt.stack([X], axis=axis)

assert_array_equal(Ynp, dpt.asnumpy(Y))

Ynp = np.stack((Xnp,), axis=axis)
Y = dpt.stack((X,), axis=axis)

assert_array_equal(Ynp, dpt.asnumpy(Y))


@pytest.mark.parametrize(
"data",
[
[(1,), 0],
[(0, 2), 0],
[(2, 0), 0],
[(2, 3), 0],
[(2, 3), 1],
[(2, 3), 2],
[(2, 3), -1],
[(2, 3), -2],
[(2, 2, 2), 1],
],
)
def test_stack_2arrays(data):
try:
q = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
pytest.skip("Queue could not be created")

shape, axis = data

Xnp = np.ones(shape)
X = dpt.asarray(Xnp, sycl_queue=q)

Ynp = np.zeros(shape)
Y = dpt.asarray(Ynp, sycl_queue=q)

Znp = np.stack([Xnp, Ynp], axis=axis)
print(Znp.shape)
Z = dpt.stack([X, Y], axis=axis)

assert_array_equal(Znp, dpt.asnumpy(Z))


@pytest.mark.parametrize(
"data",
[
[(1,), 0],
[(0, 2), 0],
[(2, 1, 2), 1],
],
)
def test_stack_3arrays(data):
try:
q = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
pytest.skip("Queue could not be created")

shape, axis = data

Xnp = np.ones(shape)
X = dpt.asarray(Xnp, sycl_queue=q)

Ynp = np.zeros(shape)
Y = dpt.asarray(Ynp, sycl_queue=q)

Znp = np.full(shape, 2.0)
Z = dpt.asarray(Znp, sycl_queue=q)

Rnp = np.stack([Xnp, Ynp, Znp], axis=axis)
R = dpt.stack([X, Y, Z], axis=axis)

assert_array_equal(Rnp, dpt.asnumpy(R))