Skip to content

Commit 2273287

Browse files
Merge pull request #1125 from IntelPython/fix_gh_1122
Added support for argument axis with value None for dpctl.tensor.concat().
2 parents 2fafa76 + cc7f714 commit 2273287

File tree

2 files changed

+70
-13
lines changed

2 files changed

+70
-13
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -411,10 +411,10 @@ def roll(X, shift, axis=None):
411411
return res
412412

413413

414-
def _arrays_validation(arrays):
414+
def _arrays_validation(arrays, check_ndim=True):
415415
n = len(arrays)
416416
if n == 0:
417-
raise TypeError("Missing 1 required positional argument: 'arrays'")
417+
raise TypeError("Missing 1 required positional argument: 'arrays'.")
418418

419419
if not isinstance(arrays, (list, tuple)):
420420
raise TypeError(f"Expected tuple or list type, got {type(arrays)}.")
@@ -425,11 +425,11 @@ def _arrays_validation(arrays):
425425

426426
exec_q = dputils.get_execution_queue([X.sycl_queue for X in arrays])
427427
if exec_q is None:
428-
raise ValueError("All the input arrays must have same sycl queue")
428+
raise ValueError("All the input arrays must have same sycl queue.")
429429

430430
res_usm_type = dputils.get_coerced_usm_type([X.usm_type for X in arrays])
431431
if res_usm_type is None:
432-
raise ValueError("All the input arrays must have usm_type")
432+
raise ValueError("All the input arrays must have usm_type.")
433433

434434
X0 = arrays[0]
435435
_supported_dtype(Xi.dtype for Xi in arrays)
@@ -438,13 +438,14 @@ def _arrays_validation(arrays):
438438
for i in range(1, n):
439439
res_dtype = np.promote_types(res_dtype, arrays[i])
440440

441-
for i in range(1, n):
442-
if X0.ndim != arrays[i].ndim:
443-
raise ValueError(
444-
"All the input arrays must have same number of dimensions, "
445-
f"but the array at index 0 has {X0.ndim} dimension(s) and the "
446-
f"array at index {i} has {arrays[i].ndim} dimension(s)"
447-
)
441+
if check_ndim:
442+
for i in range(1, n):
443+
if X0.ndim != arrays[i].ndim:
444+
raise ValueError(
445+
"All the input arrays must have same number of dimensions, "
446+
f"but the array at index 0 has {X0.ndim} dimension(s) and "
447+
f"the array at index {i} has {arrays[i].ndim} dimension(s)."
448+
)
448449
return res_dtype, res_usm_type, exec_q
449450

450451

@@ -457,10 +458,46 @@ def _check_same_shapes(X0_shape, axis, n, arrays):
457458
"All the input array dimensions for the concatenation "
458459
f"axis must match exactly, but along dimension {j}, the "
459460
f"array at index 0 has size {X0j} and the array "
460-
f"at index {i} has size {Xi_shape[j]}"
461+
f"at index {i} has size {Xi_shape[j]}."
461462
)
462463

463464

465+
def _concat_axis_None(arrays):
466+
"Implementation of concat(arrays, axis=None)."
467+
res_dtype, res_usm_type, exec_q = _arrays_validation(
468+
arrays, check_ndim=False
469+
)
470+
res_shape = 0
471+
for array in arrays:
472+
res_shape += array.size
473+
res = dpt.empty(
474+
res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
475+
)
476+
477+
hev_list = []
478+
fill_start = 0
479+
for array in arrays:
480+
fill_end = fill_start + array.size
481+
if array.flags.c_contiguous:
482+
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
483+
src=dpt.reshape(array, -1),
484+
dst=res[fill_start:fill_end],
485+
sycl_queue=exec_q,
486+
)
487+
else:
488+
hev, _ = ti._copy_usm_ndarray_for_reshape(
489+
src=array,
490+
dst=res[fill_start:fill_end],
491+
shift=0,
492+
sycl_queue=exec_q,
493+
)
494+
fill_start = fill_end
495+
hev_list.append(hev)
496+
497+
dpctl.SyclEvent.wait_for(hev_list)
498+
return res
499+
500+
464501
def concat(arrays, axis=0):
465502
"""concat(arrays, axis)
466503
@@ -486,8 +523,10 @@ def concat(arrays, axis=0):
486523
of the output array is determined by USM allocation type promotion
487524
rules.
488525
"""
489-
res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)
526+
if axis is None:
527+
return _concat_axis_None(arrays)
490528

529+
res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)
491530
n = len(arrays)
492531
X0 = arrays[0]
493532

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,7 @@ def test_concat_1array(data):
839839
[(0, 2), (2, 2), 0],
840840
[(2, 1), (2, 2), -1],
841841
[(2, 2, 2), (2, 1, 2), 1],
842+
[(3, 3, 3), (2, 2), None],
842843
],
843844
)
844845
def test_concat_2arrays(data):
@@ -892,6 +893,23 @@ def test_concat_3arrays(data):
892893
assert_array_equal(Rnp, dpt.asnumpy(R))
893894

894895

896+
def test_concat_axis_none_strides():
897+
try:
898+
q = dpctl.SyclQueue()
899+
except dpctl.SyclQueueCreationError:
900+
pytest.skip("Queue could not be created")
901+
Xnp = np.arange(0, 18).reshape((6, 3))
902+
X = dpt.asarray(Xnp, sycl_queue=q)
903+
904+
Ynp = np.arange(20, 36).reshape((4, 2, 2))
905+
Y = dpt.asarray(Ynp, sycl_queue=q)
906+
907+
Znp = np.concatenate([Xnp[::2], Ynp[::2]], axis=None)
908+
Z = dpt.concat([X[::2], Y[::2]], axis=None)
909+
910+
assert_array_equal(Znp, dpt.asnumpy(Z))
911+
912+
895913
def test_stack_incorrect_shape():
896914
try:
897915
q = dpctl.SyclQueue()

0 commit comments

Comments
 (0)