Skip to content

Commit 346200e

Browse files
Fixed example for @antonwolfy, added tests based on it
Also added comment to explain the logic of the code.
1 parent 8b03642 commit 346200e

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,27 @@ def _copy_from_usm_ndarray_to_usm_ndarray(dst, src):
302302
src_same_shape = dpt.usm_ndarray(
303303
common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides
304304
)
305+
elif src.ndim == len(common_shape):
306+
new_src_strides = _broadcast_strides(
307+
src.shape, src.strides, len(common_shape)
308+
)
309+
src_same_shape = dpt.usm_ndarray(
310+
common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides
311+
)
305312
else:
306-
src_same_shape = src
307-
src_same_shape.shape = common_shape
313+
# since broadcasting succeeded, src.ndim is greater because of
314+
# leading sequence of ones, so we trim it
315+
n = len(common_shape)
316+
new_src_strides = _broadcast_strides(
317+
src.shape[-n:], src.strides[-n:], n
318+
)
319+
src_same_shape = dpt.usm_ndarray(
320+
common_shape,
321+
dtype=src.dtype,
322+
buffer=src.usm_data,
323+
strides=new_src_strides,
324+
offset=src._element_offset,
325+
)
308326

309327
_copy_same_shape(dst, src_same_shape)
310328

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,6 +1039,26 @@ def test_setitem_broadcasting_empty_dst_edge_case():
10391039
dst[...] = src
10401040

10411041

1042+
def test_setitem_broadcasting_src_ndim_equal_dst_ndim():
1043+
get_queue_or_skip()
1044+
dst = dpt.ones((2, 3, 4), dtype="i4")
1045+
src = dpt.zeros((2, 1, 4), dtype="i4")
1046+
dst[...] = src
1047+
1048+
expected = np.zeros(dst.shape, dtype=dst.dtype)
1049+
assert np.array_equal(dpt.asnumpy(dst), expected)
1050+
1051+
1052+
def test_setitem_broadcasting_src_ndim_greater_than_dst_ndim():
1053+
get_queue_or_skip()
1054+
dst = dpt.ones((2, 3, 4), dtype="i4")
1055+
src = dpt.zeros((1, 2, 1, 4), dtype="i4")
1056+
dst[...] = src
1057+
1058+
expected = np.zeros(dst.shape, dtype=dst.dtype)
1059+
assert np.array_equal(dpt.asnumpy(dst), expected)
1060+
1061+
10421062
@pytest.mark.parametrize(
10431063
"dtype",
10441064
_all_dtypes,

0 commit comments

Comments
 (0)