Skip to content

Commit bccb694

Browse files
committed
Added support for argument axis with value None for dpctl.tensor.concat().
1 parent 8f828f2 commit bccb694

File tree

2 files changed

+66
-40
lines changed

2 files changed

+66
-40
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 65 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -321,10 +321,10 @@ def roll(X, shift, axis=None):
321321
return res
322322

323323

324-
def _arrays_validation(arrays):
324+
def _arrays_validation(arrays, check_ndim=True):
325325
n = len(arrays)
326326
if n == 0:
327-
raise TypeError("Missing 1 required positional argument: 'arrays'")
327+
raise TypeError("Missing 1 required positional argument: 'arrays'.")
328328

329329
if not isinstance(arrays, (list, tuple)):
330330
raise TypeError(f"Expected tuple or list type, got {type(arrays)}.")
@@ -335,11 +335,11 @@ def _arrays_validation(arrays):
335335

336336
exec_q = dputils.get_execution_queue([X.sycl_queue for X in arrays])
337337
if exec_q is None:
338-
raise ValueError("All the input arrays must have same sycl queue")
338+
raise ValueError("All the input arrays must have same sycl queue.")
339339

340340
res_usm_type = dputils.get_coerced_usm_type([X.usm_type for X in arrays])
341341
if res_usm_type is None:
342-
raise ValueError("All the input arrays must have usm_type")
342+
raise ValueError("All the input arrays must have usm_type.")
343343

344344
X0 = arrays[0]
345345
_supported_dtype(Xi.dtype for Xi in arrays)
@@ -348,13 +348,14 @@ def _arrays_validation(arrays):
348348
for i in range(1, n):
349349
res_dtype = np.promote_types(res_dtype, arrays[i])
350350

351-
for i in range(1, n):
352-
if X0.ndim != arrays[i].ndim:
353-
raise ValueError(
354-
"All the input arrays must have same number of dimensions, "
355-
f"but the array at index 0 has {X0.ndim} dimension(s) and the "
356-
f"array at index {i} has {arrays[i].ndim} dimension(s)"
357-
)
351+
if check_ndim:
352+
for i in range(1, n):
353+
if X0.ndim != arrays[i].ndim:
354+
raise ValueError(
355+
"All the input arrays must have same number of dimensions, "
356+
f"but the array at index 0 has {X0.ndim} dimension(s) and "
357+
f"the array at index {i} has {arrays[i].ndim} dimension(s)."
358+
)
358359
return res_dtype, res_usm_type, exec_q
359360

360361

@@ -367,7 +368,7 @@ def _check_same_shapes(X0_shape, axis, n, arrays):
367368
"All the input array dimensions for the concatenation "
368369
f"axis must match exactly, but along dimension {j}, the "
369370
f"array at index 0 has size {X0j} and the array "
370-
f"at index {i} has size {Xi_shape[j]}"
371+
f"at index {i} has size {Xi_shape[j]}."
371372
)
372373

373374

@@ -377,42 +378,66 @@ def concat(arrays, axis=0):
377378
378379
Joins a sequence of arrays along an existing axis.
379380
"""
380-
res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)
381-
382-
n = len(arrays)
383-
X0 = arrays[0]
381+
if axis is None:
382+
res_dtype, res_usm_type, exec_q = _arrays_validation(
383+
arrays, check_ndim=False
384+
)
385+
res_shape = 0
386+
for array in arrays:
387+
res_shape += array.size
388+
res = dpt.empty(
389+
res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
390+
)
384391

385-
axis = normalize_axis_index(axis, X0.ndim)
386-
X0_shape = X0.shape
387-
_check_same_shapes(X0_shape, axis, n, arrays)
392+
hev_list = []
393+
fill_start = 0
394+
for array in arrays:
395+
fill_end = fill_start + array.size
396+
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
397+
src=dpt.reshape(array, -1),
398+
dst=res[fill_start:fill_end],
399+
sycl_queue=exec_q,
400+
)
401+
fill_start = fill_end
402+
hev_list.append(hev)
388403

389-
res_shape_axis = 0
390-
for X in arrays:
391-
res_shape_axis = res_shape_axis + X.shape[axis]
404+
dpctl.SyclEvent.wait_for(hev_list)
405+
else:
406+
res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)
407+
n = len(arrays)
408+
X0 = arrays[0]
392409

393-
res_shape = tuple(
394-
X0_shape[i] if i != axis else res_shape_axis for i in range(X0.ndim)
395-
)
410+
axis = normalize_axis_index(axis, X0.ndim)
411+
X0_shape = X0.shape
412+
_check_same_shapes(X0_shape, axis, n, arrays)
396413

397-
res = dpt.empty(
398-
res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
399-
)
414+
res_shape_axis = 0
415+
for X in arrays:
416+
res_shape_axis = res_shape_axis + X.shape[axis]
400417

401-
hev_list = []
402-
fill_start = 0
403-
for i in range(n):
404-
fill_end = fill_start + arrays[i].shape[axis]
405-
c_shapes_copy = tuple(
406-
np.s_[fill_start:fill_end] if j == axis else np.s_[:]
407-
for j in range(X0.ndim)
418+
res_shape = tuple(
419+
X0_shape[i] if i != axis else res_shape_axis for i in range(X0.ndim)
408420
)
409-
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
410-
src=arrays[i], dst=res[c_shapes_copy], sycl_queue=exec_q
421+
422+
res = dpt.empty(
423+
res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
411424
)
412-
fill_start = fill_end
413-
hev_list.append(hev)
414425

415-
dpctl.SyclEvent.wait_for(hev_list)
426+
hev_list = []
427+
fill_start = 0
428+
for i in range(n):
429+
fill_end = fill_start + arrays[i].shape[axis]
430+
c_shapes_copy = tuple(
431+
np.s_[fill_start:fill_end] if j == axis else np.s_[:]
432+
for j in range(X0.ndim)
433+
)
434+
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
435+
src=arrays[i], dst=res[c_shapes_copy], sycl_queue=exec_q
436+
)
437+
fill_start = fill_end
438+
hev_list.append(hev)
439+
440+
dpctl.SyclEvent.wait_for(hev_list)
416441

417442
return res
418443

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,7 @@ def test_concat_1array(data):
839839
[(0, 2), (2, 2), 0],
840840
[(2, 1), (2, 2), -1],
841841
[(2, 2, 2), (2, 1, 2), 1],
842+
[(3, 3, 3), (2, 2), None],
842843
],
843844
)
844845
def test_concat_2arrays(data):

0 commit comments

Comments
 (0)