Skip to content

Commit 368a17e

Browse files
Update _empty_like_pair_orderK to support arrays of different shapes (#1224)
* Add res_shape param to _empty_like_pair_orderK * Update tests for dpctl.tensor.add * Update test for _empty_like_pair_orderK
1 parent 1596a13 commit 368a17e

File tree

4 files changed

+56
-14
lines changed

4 files changed

+56
-14
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ def __call__(self, o1, o2, out=None, order="K"):
433433
if out is None:
434434
if order == "K":
435435
out = _empty_like_pair_orderK(
436-
src1, src2, res_dt, res_usm_type, exec_q
436+
src1, src2, res_dt, res_shape, res_usm_type, exec_q
437437
)
438438
else:
439439
if order == "A":
@@ -482,7 +482,7 @@ def __call__(self, o1, o2, out=None, order="K"):
482482
if out is None:
483483
if order == "K":
484484
out = _empty_like_pair_orderK(
485-
src1, buf2, res_dt, res_usm_type, exec_q
485+
src1, buf2, res_dt, res_shape, res_usm_type, exec_q
486486
)
487487
else:
488488
out = dpt.empty(
@@ -524,7 +524,7 @@ def __call__(self, o1, o2, out=None, order="K"):
524524
if out is None:
525525
if order == "K":
526526
out = _empty_like_pair_orderK(
527-
buf1, src2, res_dt, res_usm_type, exec_q
527+
buf1, src2, res_dt, res_shape, res_usm_type, exec_q
528528
)
529529
else:
530530
out = dpt.empty(
@@ -578,7 +578,7 @@ def __call__(self, o1, o2, out=None, order="K"):
578578
if out is None:
579579
if order == "K":
580580
out = _empty_like_pair_orderK(
581-
buf1, buf2, res_dt, res_usm_type, exec_q
581+
buf1, buf2, res_dt, res_shape, res_usm_type, exec_q
582582
)
583583
else:
584584
out = dpt.empty(

dpctl/tensor/_type_utils.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,38 +158,41 @@ def _empty_like_orderK(X, dt, usm_type=None, dev=None):
158158
return dpt.permute_dims(R, inv_perm)
159159

160160

161-
def _empty_like_pair_orderK(X1, X2, dt, usm_type, dev):
161+
def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
162162
if not isinstance(X1, dpt.usm_ndarray):
163163
raise TypeError(f"Expected usm_ndarray, got {type(X1)}")
164164
if not isinstance(X2, dpt.usm_ndarray):
165165
raise TypeError(f"Expected usm_ndarray, got {type(X2)}")
166166
nd1 = X1.ndim
167167
nd2 = X2.ndim
168-
if nd1 > nd2:
168+
if nd1 > nd2 and X1.shape == res_shape:
169169
return _empty_like_orderK(X1, dt, usm_type, dev)
170-
elif nd1 < nd2:
170+
elif nd1 < nd2 and X2.shape == res_shape:
171171
return _empty_like_orderK(X2, dt, usm_type, dev)
172172
fl1 = X1.flags
173173
fl2 = X2.flags
174174
if fl1["C"] or fl2["C"]:
175-
return dpt.empty_like(
176-
X1, dtype=dt, usm_type=usm_type, device=dev, order="C"
175+
return dpt.empty(
176+
res_shape, dtype=dt, usm_type=usm_type, device=dev, order="C"
177177
)
178178
if fl1["F"] and fl2["F"]:
179-
return dpt.empty_like(
180-
X1, dtype=dt, usm_type=usm_type, device=dev, order="F"
179+
return dpt.empty(
180+
res_shape, dtype=dt, usm_type=usm_type, device=dev, order="F"
181181
)
182182
st1 = list(X1.strides)
183183
st2 = list(X2.strides)
184+
max_ndim = max(nd1, nd2)
185+
st1 += [0] * (max_ndim - len(st1))
186+
st2 += [0] * (max_ndim - len(st2))
184187
perm = sorted(
185-
range(nd1),
188+
range(max_ndim),
186189
key=lambda d: (builtins.abs(st1[d]), builtins.abs(st2[d])),
187190
reverse=True,
188191
)
189-
inv_perm = sorted(range(nd1), key=lambda i: perm[i])
192+
inv_perm = sorted(range(max_ndim), key=lambda i: perm[i])
190193
st1_sorted = [st1[i] for i in perm]
191194
st2_sorted = [st2[i] for i in perm]
192-
sh = X1.shape
195+
sh = res_shape
193196
sh_sorted = tuple(sh[i] for i in perm)
194197
R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C")
195198
if max(min(st1_sorted), min(st2_sorted)) < 0:

dpctl/tests/elementwise/test_add.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,43 @@ def test_add_broadcasting():
156156
assert (dpt.asnumpy(r4) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()
157157

158158

159+
def test_add_broadcasting_new_shape():
160+
get_queue_or_skip()
161+
162+
ar1 = dpt.ones((6, 1), dtype="i4")
163+
ar2 = dpt.arange(6, dtype="i4")
164+
165+
r = dpt.add(ar1, ar2)
166+
assert (dpt.asnumpy(r) == np.arange(1, 7, dtype="i4")[np.newaxis, :]).all()
167+
168+
r1 = dpt.add(ar2, ar1)
169+
assert (dpt.asnumpy(r1) == np.arange(1, 7, dtype="i4")[np.newaxis, :]).all()
170+
171+
r2 = dpt.add(ar1[::2], ar2[::2])
172+
assert (
173+
dpt.asnumpy(r2) == np.arange(1, 7, dtype="i4")[::2][np.newaxis, :]
174+
).all()
175+
176+
r3 = dpt.empty_like(ar1)
177+
with pytest.raises(TypeError):
178+
dpt.add(ar1, ar2, out=r3)
179+
180+
ar3 = dpt.ones((6, 1), dtype="i4")
181+
ar4 = dpt.ones((1, 6), dtype="i4")
182+
183+
r4 = dpt.add(ar3, ar4)
184+
assert (dpt.asnumpy(r4) == np.full((6, 6), 2, dtype="i4")).all()
185+
186+
r5 = dpt.add(ar4, ar3)
187+
assert (dpt.asnumpy(r5) == np.full((6, 6), 2, dtype="i4")).all()
188+
189+
r6 = dpt.add(ar3[::2], ar4[:, ::2])
190+
assert (dpt.asnumpy(r6) == np.full((3, 3), 2, dtype="i4")).all()
191+
192+
r7 = dpt.add(ar3[::2], ar4)
193+
assert (dpt.asnumpy(r7) == np.full((3, 6), 2, dtype="i4")).all()
194+
195+
159196
def test_add_broadcasting_error():
160197
get_queue_or_skip()
161198
m = dpt.ones((10, 10), dtype="i4")

dpctl/tests/elementwise/test_type_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def test_type_utils_empty_like_orderK_invalid_args():
8989
3,
9090
),
9191
dpt.int32,
92+
(3,),
9293
"device",
9394
None,
9495
)
@@ -105,6 +106,7 @@ def test_type_utils_empty_like_orderK_invalid_args():
105106
3,
106107
),
107108
dpt.int32,
109+
(10,),
108110
"device",
109111
None,
110112
)

0 commit comments

Comments
 (0)