Added support for argument axis with value None for dpctl.tensor.concat(). #1125
View rendered docs @ https://intelpython.github.io/dpctl/pulls/1125/index.html
Array API standard conformance tests for dpctl=0.14.2=py310h76be34b_41 ran successfully.
For readability's sake, it may be better to move the implementation of the axis=None case into a separate function:

def concat(arrays, axis=0):
    "docstring"
    if axis is None:
        return _concat_axis_None(arrays)
    # code to handle integral axis here

In principle, the implementation of [...]
I think [...] Hence, I suggest implementing the axis=None case as follows:

def _concat_axis_None_alt(arrays):
    "Implementation of concat(arrays, axis=None)"
    res_dtype, res_usm_type, exec_q = _arrays_validation(
        arrays, check_ndim=False
    )
    # The result is 1D; its length is the total number of elements.
    res_shape = 0
    for array in arrays:
        res_shape += array.size
    res = dpt.empty(
        res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
    )
    hev_list = []
    fill_start = 0
    for array in arrays:
        fill_end = fill_start + array.size
        if array.flags.c_contiguous:
            # C-contiguous input: reshape to 1D and copy it directly.
            hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
                src=dpt.reshape(array, -1),
                dst=res[fill_start:fill_end],
                sycl_queue=exec_q,
            )
        else:
            # Strided input: copy into the 1D destination slice.
            hev, _ = ti._copy_usm_ndarray_for_reshape(
                src=array,
                dst=res[fill_start:fill_end],
                shift=0,
                sycl_queue=exec_q,
            )
        fill_start = fill_end
        hev_list.append(hev)
    # Wait for all submitted copies before returning the result.
    dpctl.SyclEvent.wait_for(hev_list)
    return res

This offers some tangible performance gains.
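The preallocate-and-fill strategy suggested above can be tried without a SYCL device. The following is only a minimal sketch that uses NumPy as a stand-in for dpctl.tensor; the name concat_axis_none_sketch is hypothetical and not part of dpctl, and the sketch says nothing about the performance of the real implementation.

```python
import numpy as np


def concat_axis_none_sketch(arrays):
    """NumPy stand-in for concat(arrays, axis=None): preallocate, then fill slices."""
    # Assumption: np.result_type plays the role of the dtype validation helper.
    res_dtype = np.result_type(*arrays)
    total = sum(a.size for a in arrays)
    res = np.empty(total, dtype=res_dtype)
    fill_start = 0
    for a in arrays:
        fill_end = fill_start + a.size
        # reshape(-1) yields elements in C (row-major) order for both
        # contiguous and strided inputs.
        res[fill_start:fill_end] = a.reshape(-1)
        fill_start = fill_end
    return res


x = np.arange(6, dtype="i4").reshape(2, 3)           # C-contiguous
y = np.arange(12, dtype="i4").reshape(3, 4)[:, ::2]  # strided view
out = concat_axis_none_sketch([x, y])
```

Each input is written into its own slice of a single preallocated 1D result, so no intermediate flattened copies need to be concatenated afterwards.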
Please apply the changes outlined in the earlier comment, and also please add a test to cover concat(..., axis=None)
for both C-contiguous and strided inputs.
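One possible shape for the requested test is sketched below. It is a sketch under assumptions, not the test that was added to the PR: the helper check_concat_axis_none and its parameters are hypothetical, and it is exercised here with NumPy functions so that it can run without a device.

```python
import numpy as np


def check_concat_axis_none(concat, from_numpy, to_numpy):
    """Check concat(..., axis=None) on C-contiguous and strided inputs."""
    base = np.arange(24, dtype="i4").reshape(4, 6)
    x = from_numpy(base)
    cases = {
        # (array under test, NumPy reference with the same layout)
        "c_contiguous": (x, base),
        "strided": (x[::2, ::3], base[::2, ::3]),
    }
    for name, (arr, ref) in cases.items():
        res = to_numpy(concat([arr, arr], axis=None))
        expected = np.concatenate([ref.reshape(-1), ref.reshape(-1)])
        assert res.ndim == 1, name
        assert np.array_equal(res, expected), name
    return sorted(cases)


# NumPy stand-in run; np.concatenate also accepts axis=None.
checked = check_concat_axis_none(np.concatenate, np.asarray, np.asarray)
```

With dpctl available and dpctl.tensor imported as dpt, the same helper could presumably be called as check_concat_axis_none(dpt.concat, dpt.asarray, dpt.asnumpy).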
…es for dpctl.tensor.concat().
Array API standard conformance tests for dpctl=0.14.2=py310h76be34b_51 ran successfully.
Thank you @npolina4
Deleted rendered PR docs from intelpython.github.com/dpctl, latest should be updated shortly. 🤞
Array API standard conformance tests for dpctl=0.14.2=py310h76be34b_51 ran successfully.
Closes #1122