Skip to content

Added support for argument axis with value None for dpctl.tensor.concat(). #1125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 52 additions & 13 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,10 @@ def roll(X, shift, axis=None):
return res


def _arrays_validation(arrays):
def _arrays_validation(arrays, check_ndim=True):
n = len(arrays)
if n == 0:
raise TypeError("Missing 1 required positional argument: 'arrays'")
raise TypeError("Missing 1 required positional argument: 'arrays'.")

if not isinstance(arrays, (list, tuple)):
raise TypeError(f"Expected tuple or list type, got {type(arrays)}.")
Expand All @@ -335,11 +335,11 @@ def _arrays_validation(arrays):

exec_q = dputils.get_execution_queue([X.sycl_queue for X in arrays])
if exec_q is None:
raise ValueError("All the input arrays must have same sycl queue")
raise ValueError("All the input arrays must have same sycl queue.")

res_usm_type = dputils.get_coerced_usm_type([X.usm_type for X in arrays])
if res_usm_type is None:
raise ValueError("All the input arrays must have usm_type")
raise ValueError("All the input arrays must have usm_type.")

X0 = arrays[0]
_supported_dtype(Xi.dtype for Xi in arrays)
Expand All @@ -348,13 +348,14 @@ def _arrays_validation(arrays):
for i in range(1, n):
res_dtype = np.promote_types(res_dtype, arrays[i])

for i in range(1, n):
if X0.ndim != arrays[i].ndim:
raise ValueError(
"All the input arrays must have same number of dimensions, "
f"but the array at index 0 has {X0.ndim} dimension(s) and the "
f"array at index {i} has {arrays[i].ndim} dimension(s)"
)
if check_ndim:
for i in range(1, n):
if X0.ndim != arrays[i].ndim:
raise ValueError(
"All the input arrays must have same number of dimensions, "
f"but the array at index 0 has {X0.ndim} dimension(s) and "
f"the array at index {i} has {arrays[i].ndim} dimension(s)."
)
return res_dtype, res_usm_type, exec_q


Expand All @@ -367,18 +368,56 @@ def _check_same_shapes(X0_shape, axis, n, arrays):
"All the input array dimensions for the concatenation "
f"axis must match exactly, but along dimension {j}, the "
f"array at index 0 has size {X0j} and the array "
f"at index {i} has size {Xi_shape[j]}"
f"at index {i} has size {Xi_shape[j]}."
)


def _concat_axis_None(arrays):
"Implementation of concat(arrays, axis=None)."
res_dtype, res_usm_type, exec_q = _arrays_validation(
arrays, check_ndim=False
)
res_shape = 0
for array in arrays:
res_shape += array.size
res = dpt.empty(
res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
)

hev_list = []
fill_start = 0
for array in arrays:
fill_end = fill_start + array.size
if array.flags.c_contiguous:
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=dpt.reshape(array, -1),
dst=res[fill_start:fill_end],
sycl_queue=exec_q,
)
else:
hev, _ = ti._copy_usm_ndarray_for_reshape(
src=array,
dst=res[fill_start:fill_end],
shift=0,
sycl_queue=exec_q,
)
fill_start = fill_end
hev_list.append(hev)

dpctl.SyclEvent.wait_for(hev_list)
return res


def concat(arrays, axis=0):
"""
concat(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray

Joins a sequence of arrays along an existing axis.
"""
res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)
if axis is None:
return _concat_axis_None(arrays)

res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)
n = len(arrays)
X0 = arrays[0]

Expand Down
18 changes: 18 additions & 0 deletions dpctl/tests/test_usm_ndarray_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ def test_concat_1array(data):
[(0, 2), (2, 2), 0],
[(2, 1), (2, 2), -1],
[(2, 2, 2), (2, 1, 2), 1],
[(3, 3, 3), (2, 2), None],
],
)
def test_concat_2arrays(data):
Expand Down Expand Up @@ -892,6 +893,23 @@ def test_concat_3arrays(data):
assert_array_equal(Rnp, dpt.asnumpy(R))


def test_concat_axis_none_strides():
try:
q = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
pytest.skip("Queue could not be created")
Xnp = np.arange(0, 18).reshape((6, 3))
X = dpt.asarray(Xnp, sycl_queue=q)

Ynp = np.arange(20, 36).reshape((4, 2, 2))
Y = dpt.asarray(Ynp, sycl_queue=q)

Znp = np.concatenate([Xnp[::2], Ynp[::2]], axis=None)
Z = dpt.concat([X[::2], Y[::2]], axis=None)

assert_array_equal(Znp, dpt.asnumpy(Z))


def test_stack_incorrect_shape():
try:
q = dpctl.SyclQueue()
Expand Down