Skip to content

Commit 8e04e06

Browse files
Merge pull request #867 from IntelPython/adding-concat-feature
Added dpctl.tensor.concat feature and tests
2 parents 299bd7e + 600ac20 commit 8e04e06

File tree

4 files changed

+256
-1
lines changed

4 files changed

+256
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3131
* Provided pybind11 example for functions working on `dpctl.tensor.usm_ndarray` container applying oneMKL functions [#780](https://github.com/IntelPython/dpctl/pull/780), [#793](https://github.com/IntelPython/dpctl/pull/793), [#819](https://github.com/IntelPython/dpctl/pull/819). The example was expanded to demonstrate implementing iterative linear solvers (Chebyshev solver, and Conjugate-Gradient solver) by asynchronously submitting individual SYCL kernels from Python [#821](https://github.com/IntelPython/dpctl/pull/821), [#833](https://github.com/IntelPython/dpctl/pull/833), [#838](https://github.com/IntelPython/dpctl/pull/838).
3232
* Wrote manual page about working with `dpctl.SyclQueue` [#829](https://github.com/IntelPython/dpctl/pull/829).
3333
* Added cmake scripts to dpctl package layout and a way to query the location [#853](https://github.com/IntelPython/dpctl/pull/853).
34+
* Implemented `dpctl.tensor.concat` function from array-API [#867](https://github.com/IntelPython/dpctl/867).
3435

3536

3637
### Changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from dpctl.tensor._manipulation_functions import (
4040
broadcast_arrays,
4141
broadcast_to,
42+
concat,
4243
expand_dims,
4344
flip,
4445
permute_dims,
@@ -66,6 +67,7 @@
6667
"flip",
6768
"reshape",
6869
"roll",
70+
"concat",
6971
"broadcast_arrays",
7072
"broadcast_to",
7173
"expand_dims",

dpctl/tensor/_manipulation_functions.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@
1818
from itertools import chain, product, repeat
1919

2020
import numpy as np
21-
from numpy.core.numeric import normalize_axis_tuple
21+
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
2222

2323
import dpctl
2424
import dpctl.tensor as dpt
2525
import dpctl.tensor._tensor_impl as ti
26+
import dpctl.utils as dputils
2627

2728

2829
def _broadcast_strides(X_shape, X_strides, res_ndim):
@@ -285,3 +286,90 @@ def roll(X, shift, axes=None):
285286

286287
dpctl.SyclEvent.wait_for(hev_list)
287288
return res
289+
290+
291+
def concat(arrays, axis=0):
292+
"""
293+
concat(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray
294+
295+
Joins a sequence of arrays along an existing axis.
296+
"""
297+
n = len(arrays)
298+
if n == 0:
299+
raise TypeError("Missing 1 required positional argument: 'arrays'")
300+
301+
if not isinstance(arrays, (list, tuple)):
302+
raise TypeError(f"Expected tuple or list type, got {type(arrays)}.")
303+
304+
for X in arrays:
305+
if not isinstance(X, dpt.usm_ndarray):
306+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
307+
308+
exec_q = dputils.get_execution_queue([X.sycl_queue for X in arrays])
309+
if exec_q is None:
310+
raise ValueError("All the input arrays must have same sycl queue")
311+
312+
res_usm_type = dputils.get_coerced_usm_type([X.usm_type for X in arrays])
313+
if res_usm_type is None:
314+
raise ValueError("All the input arrays must have usm_type")
315+
316+
X0 = arrays[0]
317+
if not all(Xi.dtype.char in "?bBhHiIlLqQefdFD" for Xi in arrays):
318+
raise ValueError("Unsupported dtype encountered.")
319+
320+
res_dtype = X0.dtype
321+
for i in range(1, n):
322+
res_dtype = np.promote_types(res_dtype, arrays[i])
323+
324+
for i in range(1, n):
325+
if X0.ndim != arrays[i].ndim:
326+
raise ValueError(
327+
"All the input arrays must have same number of "
328+
"dimensions, but the array at index 0 has "
329+
f"{X0.ndim} dimension(s) and the array at index "
330+
f"{i} has {arrays[i].ndim} dimension(s)"
331+
)
332+
333+
axis = normalize_axis_index(axis, X0.ndim)
334+
X0_shape = X0.shape
335+
for i in range(1, n):
336+
Xi_shape = arrays[i].shape
337+
for j in range(X0.ndim):
338+
if X0_shape[j] != Xi_shape[j] and j != axis:
339+
raise ValueError(
340+
"All the input array dimensions for the "
341+
"concatenation axis must match exactly, but "
342+
f"along dimension {j}, the array at index 0 "
343+
f"has size {X0_shape[j]} and the array at "
344+
f"index {i} has size {Xi_shape[j]}"
345+
)
346+
347+
res_shape_axis = 0
348+
for X in arrays:
349+
res_shape_axis = res_shape_axis + X.shape[axis]
350+
351+
res_shape = tuple(
352+
X0_shape[i] if i != axis else res_shape_axis for i in range(X0.ndim)
353+
)
354+
355+
res = dpt.empty(
356+
res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
357+
)
358+
359+
hev_list = []
360+
fill_start = 0
361+
for i in range(n):
362+
fill_end = fill_start + arrays[i].shape[axis]
363+
c_shapes_copy = tuple(
364+
np.s_[fill_start:fill_end] if j == axis else np.s_[:]
365+
for j in range(X0.ndim)
366+
)
367+
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
368+
src=arrays[i], dst=res[c_shapes_copy], sycl_queue=exec_q
369+
)
370+
fill_start = fill_end
371+
hev_list.append(hev)
372+
373+
dpctl.SyclEvent.wait_for(hev_list)
374+
375+
return res

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,3 +725,167 @@ def test_roll_2d(data):
725725
Y = dpt.roll(X, sh, ax)
726726
Ynp = np.roll(Xnp, sh, ax)
727727
assert_array_equal(Ynp, dpt.asnumpy(Y))
728+
729+
730+
def test_concat_incorrect_type():
731+
Xnp = np.ones((2, 2))
732+
pytest.raises(TypeError, dpt.concat)
733+
pytest.raises(TypeError, dpt.concat, [])
734+
pytest.raises(TypeError, dpt.concat, Xnp)
735+
pytest.raises(TypeError, dpt.concat, [Xnp, Xnp])
736+
737+
738+
def test_concat_incorrect_queue():
739+
try:
740+
q1 = dpctl.SyclQueue()
741+
q2 = dpctl.SyclQueue()
742+
except dpctl.SyclQueueCreationError:
743+
pytest.skip("Queue could not be created")
744+
745+
X = dpt.ones((2, 2), sycl_queue=q1)
746+
Y = dpt.ones((2, 2), sycl_queue=q2)
747+
748+
pytest.raises(ValueError, dpt.concat, [X, Y])
749+
750+
751+
def test_concat_different_dtype():
752+
try:
753+
q = dpctl.SyclQueue()
754+
except dpctl.SyclQueueCreationError:
755+
pytest.skip("Queue could not be created")
756+
757+
X = dpt.ones((2, 2), dtype=np.int64, sycl_queue=q)
758+
Y = dpt.ones((3, 2), dtype=np.uint32, sycl_queue=q)
759+
760+
XY = dpt.concat([X, Y])
761+
762+
assert XY.dtype is X.dtype
763+
assert XY.shape == (5, 2)
764+
assert XY.sycl_queue == q
765+
766+
767+
def test_concat_incorrect_ndim():
768+
try:
769+
q = dpctl.SyclQueue()
770+
except dpctl.SyclQueueCreationError:
771+
pytest.skip("Queue could not be created")
772+
773+
X = dpt.ones((2, 2), sycl_queue=q)
774+
Y = dpt.ones((2, 2, 2), sycl_queue=q)
775+
776+
pytest.raises(ValueError, dpt.concat, [X, Y])
777+
778+
779+
@pytest.mark.parametrize(
780+
"data",
781+
[
782+
[(2, 2), (3, 3), 0],
783+
[(2, 2), (3, 3), 1],
784+
[(3, 2), (3, 3), 0],
785+
[(2, 3), (3, 3), 1],
786+
],
787+
)
788+
def test_concat_incorrect_shape(data):
789+
try:
790+
q = dpctl.SyclQueue()
791+
except dpctl.SyclQueueCreationError:
792+
pytest.skip("Queue could not be created")
793+
794+
Xshape, Yshape, axis = data
795+
796+
X = dpt.ones(Xshape, sycl_queue=q)
797+
Y = dpt.ones(Yshape, sycl_queue=q)
798+
799+
pytest.raises(ValueError, dpt.concat, [X, Y], axis)
800+
801+
802+
@pytest.mark.parametrize(
803+
"data",
804+
[
805+
[(6,), 0],
806+
[(2, 3), 1],
807+
[(3, 2), -1],
808+
[(1, 6), 0],
809+
[(2, 1, 3), 2],
810+
],
811+
)
812+
def test_concat_1array(data):
813+
try:
814+
q = dpctl.SyclQueue()
815+
except dpctl.SyclQueueCreationError:
816+
pytest.skip("Queue could not be created")
817+
818+
Xshape, axis = data
819+
820+
Xnp = np.arange(6).reshape(Xshape)
821+
X = dpt.asarray(Xnp, sycl_queue=q)
822+
823+
Ynp = np.concatenate([Xnp], axis=axis)
824+
Y = dpt.concat([X], axis=axis)
825+
826+
assert_array_equal(Ynp, dpt.asnumpy(Y))
827+
828+
Ynp = np.concatenate((Xnp,), axis=axis)
829+
Y = dpt.concat((X,), axis=axis)
830+
831+
assert_array_equal(Ynp, dpt.asnumpy(Y))
832+
833+
834+
@pytest.mark.parametrize(
835+
"data",
836+
[
837+
[(1,), (1,), 0],
838+
[(0, 2), (2, 2), 0],
839+
[(2, 1), (2, 2), -1],
840+
[(2, 2, 2), (2, 1, 2), 1],
841+
],
842+
)
843+
def test_concat_2arrays(data):
844+
try:
845+
q = dpctl.SyclQueue()
846+
except dpctl.SyclQueueCreationError:
847+
pytest.skip("Queue could not be created")
848+
849+
Xshape, Yshape, axis = data
850+
851+
Xnp = np.ones(Xshape)
852+
X = dpt.asarray(Xnp, sycl_queue=q)
853+
854+
Ynp = np.zeros(Yshape)
855+
Y = dpt.asarray(Ynp, sycl_queue=q)
856+
857+
Znp = np.concatenate([Xnp, Ynp], axis=axis)
858+
Z = dpt.concat([X, Y], axis=axis)
859+
860+
assert_array_equal(Znp, dpt.asnumpy(Z))
861+
862+
863+
@pytest.mark.parametrize(
864+
"data",
865+
[
866+
[(1,), (1,), (1,), 0],
867+
[(0, 2), (2, 2), (1, 2), 0],
868+
[(2, 1, 2), (2, 2, 2), (2, 4, 2), 1],
869+
],
870+
)
871+
def test_concat_3arrays(data):
872+
try:
873+
q = dpctl.SyclQueue()
874+
except dpctl.SyclQueueCreationError:
875+
pytest.skip("Queue could not be created")
876+
877+
Xshape, Yshape, Zshape, axis = data
878+
879+
Xnp = np.ones(Xshape)
880+
X = dpt.asarray(Xnp, sycl_queue=q)
881+
882+
Ynp = np.zeros(Yshape)
883+
Y = dpt.asarray(Ynp, sycl_queue=q)
884+
885+
Znp = np.full(Zshape, 2.0)
886+
Z = dpt.asarray(Znp, sycl_queue=q)
887+
888+
Rnp = np.concatenate([Xnp, Ynp, Znp], axis=axis)
889+
R = dpt.concat([X, Y, Z], axis=axis)
890+
891+
assert_array_equal(Rnp, dpt.asnumpy(R))

0 commit comments

Comments
 (0)