Skip to content

Commit 04142cd

Browse files
committed
Adding dpctl.tensor.concat feature and tests.
concat() function concatenates several arrays along one of axis. https://data-apis.org/array-api/latest/API_specification/generated/signatures.manipulation_functions.concat.html
1 parent 299bd7e commit 04142cd

File tree

3 files changed

+247
-1
lines changed

3 files changed

+247
-1
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from dpctl.tensor._manipulation_functions import (
4040
broadcast_arrays,
4141
broadcast_to,
42+
concat,
4243
expand_dims,
4344
flip,
4445
permute_dims,
@@ -66,6 +67,7 @@
6667
"flip",
6768
"reshape",
6869
"roll",
70+
"concat",
6971
"broadcast_arrays",
7072
"broadcast_to",
7173
"expand_dims",

dpctl/tensor/_manipulation_functions.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@
1818
from itertools import chain, product, repeat
1919

2020
import numpy as np
21-
from numpy.core.numeric import normalize_axis_tuple
21+
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
2222

2323
import dpctl
2424
import dpctl.tensor as dpt
2525
import dpctl.tensor._tensor_impl as ti
26+
import dpctl.utils as dputils
2627

2728

2829
def _broadcast_strides(X_shape, X_strides, res_ndim):
@@ -285,3 +286,86 @@ def roll(X, shift, axes=None):
285286

286287
dpctl.SyclEvent.wait_for(hev_list)
287288
return res
289+
290+
291+
def concat(arrays, axis=0):
292+
"""
293+
concat(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray
294+
295+
Joins a sequence of arrays along an existing axis.
296+
"""
297+
n = len(arrays)
298+
if n == 0:
299+
raise TypeError("Missing 1 required positional argument: 'arrays'")
300+
301+
if not isinstance(arrays, list) and not isinstance(arrays, tuple):
302+
raise TypeError(f"Expected tuple or list type, got {type(arrays)}.")
303+
304+
for X in arrays:
305+
if not isinstance(X, dpt.usm_ndarray):
306+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
307+
308+
exec_q = dputils.get_execution_queue([X.sycl_queue for X in arrays])
309+
if exec_q is None:
310+
raise ValueError("All the input arrays must have same sycl queue")
311+
312+
res_usm_type = dputils.get_coerced_usm_type([X.usm_type for X in arrays])
313+
if res_usm_type is None:
314+
raise ValueError("All the input arrays must have usm_type")
315+
316+
X0 = arrays[0]
317+
if any(X0.dtype != arrays[i].dtype for i in range(1, n)):
318+
raise ValueError("All the input arrays must have same dtype")
319+
320+
for i in range(1, n):
321+
if X0.ndim != arrays[i].ndim:
322+
raise ValueError(
323+
"All the input arrays must have same number of "
324+
"dimensions, but the array at index 0 has "
325+
f"{X0.ndim} dimension(s) and the array at index "
326+
f"{i} has {arrays[i].ndim} dimension(s)"
327+
)
328+
329+
axis = normalize_axis_index(axis, X0.ndim)
330+
X0_shape = X0.shape
331+
for i in range(1, n):
332+
Xi_shape = arrays[i].shape
333+
for j in range(X0.ndim):
334+
if X0_shape[j] != Xi_shape[j] and j != axis:
335+
raise ValueError(
336+
"All the input array dimensions for the "
337+
"concatenation axis must match exactly, but "
338+
f"along dimension {j}, the array at index 0 "
339+
f"has size {X0_shape[j]} and the array at "
340+
f"index {i} has size {Xi_shape[j]}"
341+
)
342+
343+
res_shape_axis = 0
344+
for X in arrays:
345+
res_shape_axis = res_shape_axis + X.shape[axis]
346+
347+
res_shape = tuple(
348+
X0_shape[i] if i != axis else res_shape_axis for i in range(X0.ndim)
349+
)
350+
351+
res = dpt.empty(
352+
res_shape, dtype=X0.dtype, usm_type=res_usm_type, sycl_queue=exec_q
353+
)
354+
355+
hev_list = []
356+
fill_start = 0
357+
for i in range(n):
358+
fill_end = fill_start + arrays[i].shape[axis]
359+
c_shapes_copy = tuple(
360+
np.s_[fill_start:fill_end] if j == axis else np.s_[:]
361+
for j in range(X0.ndim)
362+
)
363+
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
364+
src=arrays[i], dst=res[c_shapes_copy], sycl_queue=exec_q
365+
)
366+
fill_start = fill_end
367+
hev_list.append(hev)
368+
369+
dpctl.SyclEvent.wait_for(hev_list)
370+
371+
return res

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,3 +725,163 @@ def test_roll_2d(data):
725725
Y = dpt.roll(X, sh, ax)
726726
Ynp = np.roll(Xnp, sh, ax)
727727
assert_array_equal(Ynp, dpt.asnumpy(Y))
728+
729+
730+
def test_concat_incorrect_type():
731+
Xnp = np.ones((2, 2))
732+
pytest.raises(TypeError, dpt.concat)
733+
pytest.raises(TypeError, dpt.concat, [])
734+
pytest.raises(TypeError, dpt.concat, Xnp)
735+
pytest.raises(TypeError, dpt.concat, [Xnp, Xnp])
736+
737+
738+
def test_concat_incorrect_queue():
739+
try:
740+
q1 = dpctl.SyclQueue()
741+
q2 = dpctl.SyclQueue()
742+
except dpctl.SyclQueueCreationError:
743+
pytest.skip("Queue could not be created")
744+
745+
X = dpt.ones((2, 2), sycl_queue=q1)
746+
Y = dpt.ones((2, 2), sycl_queue=q2)
747+
748+
pytest.raises(ValueError, dpt.concat, [X, Y])
749+
750+
751+
def test_concat_incorrect_dtype():
752+
try:
753+
q = dpctl.SyclQueue()
754+
except dpctl.SyclQueueCreationError:
755+
pytest.skip("Queue could not be created")
756+
757+
X = dpt.ones((2, 2), dtype=np.int64, sycl_queue=q)
758+
Y = dpt.ones((2, 2), dtype=np.uint64, sycl_queue=q)
759+
760+
pytest.raises(ValueError, dpt.concat, [X, Y])
761+
762+
763+
def test_concat_incorrect_ndim():
764+
try:
765+
q = dpctl.SyclQueue()
766+
except dpctl.SyclQueueCreationError:
767+
pytest.skip("Queue could not be created")
768+
769+
X = dpt.ones((2, 2), sycl_queue=q)
770+
Y = dpt.ones((2, 2, 2), sycl_queue=q)
771+
772+
pytest.raises(ValueError, dpt.concat, [X, Y])
773+
774+
775+
@pytest.mark.parametrize(
776+
"data",
777+
[
778+
[(2, 2), (3, 3), 0],
779+
[(2, 2), (3, 3), 1],
780+
[(3, 2), (3, 3), 0],
781+
[(2, 3), (3, 3), 1],
782+
],
783+
)
784+
def test_concat_incorrect_shape(data):
785+
try:
786+
q = dpctl.SyclQueue()
787+
except dpctl.SyclQueueCreationError:
788+
pytest.skip("Queue could not be created")
789+
790+
Xshape, Yshape, axis = data
791+
792+
X = dpt.ones(Xshape, sycl_queue=q)
793+
Y = dpt.ones(Yshape, sycl_queue=q)
794+
795+
pytest.raises(ValueError, dpt.concat, [X, Y], axis)
796+
797+
798+
@pytest.mark.parametrize(
799+
"data",
800+
[
801+
[(6,), 0],
802+
[(2, 3), 1],
803+
[(3, 2), -1],
804+
[(1, 6), 0],
805+
[(2, 1, 3), 2],
806+
],
807+
)
808+
def test_concat_1array(data):
809+
try:
810+
q = dpctl.SyclQueue()
811+
except dpctl.SyclQueueCreationError:
812+
pytest.skip("Queue could not be created")
813+
814+
Xshape, axis = data
815+
816+
Xnp = np.arange(6).reshape(Xshape)
817+
X = dpt.asarray(Xnp, sycl_queue=q)
818+
819+
Ynp = np.concatenate([Xnp], axis=axis)
820+
Y = dpt.concat([X], axis=axis)
821+
822+
assert_array_equal(Ynp, dpt.asnumpy(Y))
823+
824+
Ynp = np.concatenate((Xnp,), axis=axis)
825+
Y = dpt.concat((X,), axis=axis)
826+
827+
assert_array_equal(Ynp, dpt.asnumpy(Y))
828+
829+
830+
@pytest.mark.parametrize(
831+
"data",
832+
[
833+
[(1,), (1,), 0],
834+
[(0, 2), (2, 2), 0],
835+
[(2, 1), (2, 2), -1],
836+
[(2, 2, 2), (2, 1, 2), 1],
837+
],
838+
)
839+
def test_concat_2arrays(data):
840+
try:
841+
q = dpctl.SyclQueue()
842+
except dpctl.SyclQueueCreationError:
843+
pytest.skip("Queue could not be created")
844+
845+
Xshape, Yshape, axis = data
846+
847+
Xnp = np.ones(Xshape)
848+
X = dpt.asarray(Xnp, sycl_queue=q)
849+
850+
Ynp = np.zeros(Yshape)
851+
Y = dpt.asarray(Ynp, sycl_queue=q)
852+
853+
Znp = np.concatenate([Xnp, Ynp], axis=axis)
854+
Z = dpt.concat([X, Y], axis=axis)
855+
856+
assert_array_equal(Znp, dpt.asnumpy(Z))
857+
858+
859+
@pytest.mark.parametrize(
860+
"data",
861+
[
862+
[(1,), (1,), (1,), 0],
863+
[(0, 2), (2, 2), (1, 2), 0],
864+
[(2, 1, 2), (2, 2, 2), (2, 4, 2), 1],
865+
],
866+
)
867+
def test_concat_3arrays(data):
868+
try:
869+
q = dpctl.SyclQueue()
870+
except dpctl.SyclQueueCreationError:
871+
pytest.skip("Queue could not be created")
872+
873+
Xshape, Yshape, Zshape, axis = data
874+
875+
Xnp = np.ones(Xshape)
876+
X = dpt.asarray(Xnp, sycl_queue=q)
877+
878+
Ynp = np.zeros(Yshape)
879+
Y = dpt.asarray(Ynp, sycl_queue=q)
880+
881+
Znp = np.full(Zshape, 2.0)
882+
Z = dpt.asarray(Znp, sycl_queue=q)
883+
884+
Rnp = np.concatenate([Xnp, Ynp, Znp], axis=axis)
885+
R = dpt.concat([X, Y, Z], axis=axis)
886+
887+
assert_array_equal(Rnp, dpt.asnumpy(R))

0 commit comments

Comments
 (0)