Skip to content

Commit cc7f714

Browse files
committed
Added support for axis=None with strided (non-contiguous) input arrays in dpctl.tensor.concat().
1 parent bccb694 commit cc7f714

File tree

2 files changed

+82
-51
lines changed

2 files changed

+82
-51
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 65 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -372,72 +372,86 @@ def _check_same_shapes(X0_shape, axis, n, arrays):
372372
)
373373

374374

def _concat_axis_None(arrays):
    """
    Implementation of concat(arrays, axis=None).

    Copies the elements of every input array, one array after another,
    into a newly allocated one-dimensional array.

    Args:
        arrays: sequence of usm_ndarrays. They are checked by
            `_arrays_validation(arrays, check_ndim=False)`, which also
            supplies the data type, USM type and queue used for the result.

    Returns:
        One-dimensional usm_ndarray whose size is the sum of the sizes of
        the inputs.
    """
    res_dtype, res_usm_type, exec_q = _arrays_validation(
        arrays, check_ndim=False
    )
    res_size = sum(array.size for array in arrays)
    res = dpt.empty(
        res_size, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
    )

    hev_list = []
    fill_start = 0
    for array in arrays:
        fill_end = fill_start + array.size
        dst = res[fill_start:fill_end]
        if array.flags.c_contiguous:
            # A C-contiguous input is flattened with dpt.reshape and then
            # copied into its window of the result.
            hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
                src=dpt.reshape(array, -1), dst=dst, sycl_queue=exec_q
            )
        else:
            src = array
            if array.dtype != res_dtype:
                # NOTE(review): _copy_usm_ndarray_for_reshape is expected to
                # require src and dst of the same data type - confirm
                # against the _tensor_impl bindings. Cast into a temporary
                # of the result type first so mixed-dtype strided inputs
                # are handled.
                src = dpt.empty(
                    array.shape,
                    dtype=res_dtype,
                    usm_type=res_usm_type,
                    sycl_queue=exec_q,
                )
                cast_hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
                    src=array, dst=src, sycl_queue=exec_q
                )
                # Wait here so the flattening copy below never reads a
                # partially written temporary.
                cast_hev.wait()
            hev, _ = ti._copy_usm_ndarray_for_reshape(
                src=src, dst=dst, shift=0, sycl_queue=exec_q
            )
        hev_list.append(hev)
        fill_start = fill_end

    # Do not return the result until every submitted copy has completed.
    dpctl.SyclEvent.wait_for(hev_list)
    return res
409+
410+
def concat(arrays, axis=0):
    """
    concat(arrays: tuple or list of usm_ndarrays, axis: int or None)
        -> usm_ndarray

    Joins a sequence of arrays along an existing axis.

    Args:
        arrays: tuple or list of usm_ndarrays to join.
        axis: axis along which the arrays are joined (default 0). When it
            is None the call is forwarded to `_concat_axis_None(arrays)`.

    Returns:
        Newly allocated usm_ndarray. Its extent along `axis` is the sum of
        the inputs' extents along that axis; the remaining extents are
        taken from the first array.
    """
    if axis is None:
        return _concat_axis_None(arrays)

    # Data type, USM type and queue of the result come from validation.
    out_dtype, out_usm_type, queue = _arrays_validation(arrays)
    first = arrays[0]
    ndim = first.ndim
    first_shape = first.shape

    axis = normalize_axis_index(axis, ndim)
    _check_same_shapes(first_shape, axis, len(arrays), arrays)

    # Result shape: the first array's shape with the joined axis widened.
    dims = list(first_shape)
    dims[axis] = sum(arr.shape[axis] for arr in arrays)
    out = dpt.empty(
        tuple(dims), dtype=out_dtype, usm_type=out_usm_type, sycl_queue=queue
    )

    host_events = []
    offset = 0
    for arr in arrays:
        stop = offset + arr.shape[axis]
        # Full slices everywhere except the joined axis, where the window
        # [offset, stop) selects this array's place in the result.
        window = [slice(None)] * ndim
        window[axis] = slice(offset, stop)
        host_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
            src=arr, dst=out[tuple(window)], sycl_queue=queue
        )
        host_events.append(host_ev)
        offset = stop

    # Do not return the result until every submitted copy has completed.
    dpctl.SyclEvent.wait_for(host_events)

    return out
443457

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,23 @@ def test_concat_3arrays(data):
893893
assert_array_equal(Rnp, dpt.asnumpy(R))
894894

895895

def test_concat_axis_none_strides():
    """Check concat(axis=None) on strided inputs against NumPy."""
    try:
        queue = dpctl.SyclQueue()
    except dpctl.SyclQueueCreationError:
        pytest.skip("Queue could not be created")

    host_2d = np.arange(0, 18).reshape((6, 3))
    host_3d = np.arange(20, 36).reshape((4, 2, 2))
    dev_2d = dpt.asarray(host_2d, sycl_queue=queue)
    dev_3d = dpt.asarray(host_3d, sycl_queue=queue)

    # Every second entry along the leading axis gives inputs of different
    # rank that are views with a step, not whole arrays.
    expected = np.concatenate([host_2d[::2], host_3d[::2]], axis=None)
    actual = dpt.concat([dev_2d[::2], dev_3d[::2]], axis=None)

    assert_array_equal(expected, dpt.asnumpy(actual))
911+
912+
896913
def test_stack_incorrect_shape():
897914
try:
898915
q = dpctl.SyclQueue()

0 commit comments

Comments
 (0)