Skip to content

Commit 7453bf7

Browse files
Corrected order="K" handling for binary functions in some cases
When both inputs must be promoted, e.g. `divide(boolean, integral)`, and order="K" is requested for inputs that are neither both F-contiguous nor both C-contiguous, the temporary buffers are now created using `_empty_like_orderK`, and the result is then created from those buffers using `_empty_like_pair_orderK`. This resolves the test failure for `divide`.
1 parent 85d468c commit 7453bf7

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -470,23 +470,36 @@ def __call__(self, o1, o2, order="K"):
470470
if order in ["K", "A"]:
471471
if src1.flags.f_contiguous and src2.flags.f_contiguous:
472472
order = "F"
473-
else:
473+
elif src1.flags.c_contiguous and src2.flags.c_contiguous:
474474
order = "C"
475-
buf1 = dpt.empty_like(src1, dtype=buf1_dt, order=order)
475+
else:
476+
order = "C" if order == "A" else "K"
477+
if order == "K":
478+
buf1 = _empty_like_orderK(src1, buf1_dt)
479+
else:
480+
buf1 = dpt.empty_like(src1, dtype=buf1_dt, order=order)
476481
ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
477482
src=src1, dst=buf1, sycl_queue=exec_q
478483
)
479-
buf2 = dpt.empty_like(src2, dtype=buf2_dt, order=order)
484+
if order == "K":
485+
buf2 = _empty_like_orderK(src2, buf2_dt)
486+
else:
487+
buf2 = dpt.empty_like(src2, dtype=buf2_dt, order=order)
480488
ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
481489
src=src2, dst=buf2, sycl_queue=exec_q
482490
)
483-
r = dpt.empty(
484-
res_shape,
485-
dtype=res_dt,
486-
usm_type=res_usm_type,
487-
sycl_queue=exec_q,
488-
order=order,
489-
)
491+
if order == "K":
492+
r = _empty_like_pair_orderK(
493+
buf1, buf2, res_dt, res_usm_type, exec_q
494+
)
495+
else:
496+
r = dpt.empty(
497+
res_shape,
498+
dtype=res_dt,
499+
usm_type=res_usm_type,
500+
sycl_queue=exec_q,
501+
order=order,
502+
)
490503
buf1 = dpt.broadcast_to(buf1, res_shape)
491504
buf2 = dpt.broadcast_to(buf2, res_shape)
492505
ht_, _ = self.binary_fn_(

0 commit comments

Comments
 (0)