Skip to content

Update _empty_like_pair_orderK to support arrays of different shapes #1224

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def __call__(self, o1, o2, out=None, order="K"):
if out is None:
if order == "K":
out = _empty_like_pair_orderK(
src1, src2, res_dt, res_usm_type, exec_q
src1, src2, res_dt, res_shape, res_usm_type, exec_q
)
else:
if order == "A":
Expand Down Expand Up @@ -482,7 +482,7 @@ def __call__(self, o1, o2, out=None, order="K"):
if out is None:
if order == "K":
out = _empty_like_pair_orderK(
src1, buf2, res_dt, res_usm_type, exec_q
src1, buf2, res_dt, res_shape, res_usm_type, exec_q
)
else:
out = dpt.empty(
Expand Down Expand Up @@ -524,7 +524,7 @@ def __call__(self, o1, o2, out=None, order="K"):
if out is None:
if order == "K":
out = _empty_like_pair_orderK(
buf1, src2, res_dt, res_usm_type, exec_q
buf1, src2, res_dt, res_shape, res_usm_type, exec_q
)
else:
out = dpt.empty(
Expand Down Expand Up @@ -578,7 +578,7 @@ def __call__(self, o1, o2, out=None, order="K"):
if out is None:
if order == "K":
out = _empty_like_pair_orderK(
buf1, buf2, res_dt, res_usm_type, exec_q
buf1, buf2, res_dt, res_shape, res_usm_type, exec_q
)
else:
out = dpt.empty(
Expand Down
23 changes: 13 additions & 10 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,38 +158,41 @@ def _empty_like_orderK(X, dt, usm_type=None, dev=None):
return dpt.permute_dims(R, inv_perm)


def _empty_like_pair_orderK(X1, X2, dt, usm_type, dev):
def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
if not isinstance(X1, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray, got {type(X1)}")
if not isinstance(X2, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray, got {type(X2)}")
nd1 = X1.ndim
nd2 = X2.ndim
if nd1 > nd2:
if nd1 > nd2 and X1.shape == res_shape:
return _empty_like_orderK(X1, dt, usm_type, dev)
elif nd1 < nd2:
elif nd1 < nd2 and X2.shape == res_shape:
return _empty_like_orderK(X2, dt, usm_type, dev)
fl1 = X1.flags
fl2 = X2.flags
if fl1["C"] or fl2["C"]:
return dpt.empty_like(
X1, dtype=dt, usm_type=usm_type, device=dev, order="C"
return dpt.empty(
res_shape, dtype=dt, usm_type=usm_type, device=dev, order="C"
)
if fl1["F"] and fl2["F"]:
return dpt.empty_like(
X1, dtype=dt, usm_type=usm_type, device=dev, order="F"
return dpt.empty(
res_shape, dtype=dt, usm_type=usm_type, device=dev, order="F"
)
st1 = list(X1.strides)
st2 = list(X2.strides)
max_ndim = max(nd1, nd2)
st1 += [0] * (max_ndim - len(st1))
st2 += [0] * (max_ndim - len(st2))
perm = sorted(
range(nd1),
range(max_ndim),
key=lambda d: (builtins.abs(st1[d]), builtins.abs(st2[d])),
reverse=True,
)
inv_perm = sorted(range(nd1), key=lambda i: perm[i])
inv_perm = sorted(range(max_ndim), key=lambda i: perm[i])
st1_sorted = [st1[i] for i in perm]
st2_sorted = [st2[i] for i in perm]
sh = X1.shape
sh = res_shape
sh_sorted = tuple(sh[i] for i in perm)
R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C")
if max(min(st1_sorted), min(st2_sorted)) < 0:
Expand Down
37 changes: 37 additions & 0 deletions dpctl/tests/elementwise/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,43 @@ def test_add_broadcasting():
assert (dpt.asnumpy(r4) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()


def test_add_broadcasting_new_shape():
    """Exercise ``dpt.add`` on operands whose shapes differ from each other.

    Covers a ``(6, 1)`` array combined with a 1-D array of length 6, and a
    ``(6, 1)`` array combined with a ``(1, 6)`` array, in both operand
    orders and with strided slices of the inputs. Also checks that passing
    an ``out`` array allocated like the ``(6, 1)`` operand raises
    ``TypeError``.
    """
    get_queue_or_skip()

    column = dpt.ones((6, 1), dtype="i4")
    ramp = dpt.arange(6, dtype="i4")
    # Single expected row; the comparisons below rely on NumPy broadcasting
    # this row against the full result.
    expected_row = np.arange(1, 7, dtype="i4")[np.newaxis, :]

    # The result must match regardless of operand order.
    for lhs, rhs in ((column, ramp), (ramp, column)):
        total = dpt.add(lhs, rhs)
        assert (dpt.asnumpy(total) == expected_row).all()

    # Every-other-element views of both operands.
    strided_total = dpt.add(column[::2], ramp[::2])
    assert (dpt.asnumpy(strided_total) == expected_row[:, ::2]).all()

    # ``out`` is allocated with the shape of the (6, 1) operand; the test
    # expects the call to be rejected with TypeError.
    bad_out = dpt.empty_like(column)
    with pytest.raises(TypeError):
        dpt.add(column, ramp, out=bad_out)

    tall = dpt.ones((6, 1), dtype="i4")
    wide = dpt.ones((1, 6), dtype="i4")
    all_twos = np.full((6, 6), 2, dtype="i4")

    # (6, 1) with (1, 6), in both operand orders.
    for lhs, rhs in ((tall, wide), (wide, tall)):
        total = dpt.add(lhs, rhs)
        assert (dpt.asnumpy(total) == all_twos).all()

    # Strided views of both operands.
    both_strided = dpt.add(tall[::2], wide[:, ::2])
    assert (dpt.asnumpy(both_strided) == np.full((3, 3), 2, dtype="i4")).all()

    # Strided view of only the first operand.
    one_strided = dpt.add(tall[::2], wide)
    assert (dpt.asnumpy(one_strided) == np.full((3, 6), 2, dtype="i4")).all()


def test_add_broadcasting_error():
get_queue_or_skip()
m = dpt.ones((10, 10), dtype="i4")
Expand Down
2 changes: 2 additions & 0 deletions dpctl/tests/elementwise/test_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def test_type_utils_empty_like_orderK_invalid_args():
3,
),
dpt.int32,
(3,),
"device",
None,
)
Expand All @@ -105,6 +106,7 @@ def test_type_utils_empty_like_orderK_invalid_args():
3,
),
dpt.int32,
(10,),
"device",
None,
)
Expand Down