Skip to content

Commit 91e646d

Browse files
committed
Where result now keeps order of operands
- Now when operands are cast, stride simplification can still be performed on non-C contiguous inputs - Implements _empty_like_triple_orderK to allocate output of where
1 parent b055ff9 commit 91e646d

File tree

2 files changed

+66
-21
lines changed

2 files changed

+66
-21
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,67 @@ def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
380380
return dpt.permute_dims(R, inv_perm)
381381

382382

383+
def _empty_like_triple_orderK(X1, X2, X3, dt, res_shape, usm_type, dev):
384+
if not isinstance(X1, dpt.usm_ndarray):
385+
raise TypeError(f"Expected usm_ndarray, got {type(X1)}")
386+
if not isinstance(X2, dpt.usm_ndarray):
387+
raise TypeError(f"Expected usm_ndarray, got {type(X2)}")
388+
if not isinstance(X3, dpt.usm_ndarray):
389+
raise TypeError(f"Expected usm_ndarray, got {type(X3)}")
390+
nd1 = X1.ndim
391+
nd2 = X2.ndim
392+
nd3 = X3.ndim
393+
if nd1 > nd2 and nd1 > nd3 and X1.shape == res_shape:
394+
return _empty_like_orderK(X1, dt, usm_type, dev)
395+
elif nd1 < nd2 and nd3 < nd2 and X2.shape == res_shape:
396+
return _empty_like_orderK(X2, dt, usm_type, dev)
397+
elif nd1 < nd3 and nd2 < nd3 and X3.shape == res_shape:
398+
return _empty_like_orderK(X3, dt, usm_type, dev)
399+
fl1 = X1.flags
400+
fl2 = X2.flags
401+
fl3 = X3.flags
402+
if fl1["C"] or fl2["C"] or fl3["C"]:
403+
return dpt.empty(
404+
res_shape, dtype=dt, usm_type=usm_type, device=dev, order="C"
405+
)
406+
if fl1["F"] and fl2["F"] and fl3["F"]:
407+
return dpt.empty(
408+
res_shape, dtype=dt, usm_type=usm_type, device=dev, order="F"
409+
)
410+
st1 = list(X1.strides)
411+
st2 = list(X2.strides)
412+
st3 = list(X3.strides)
413+
max_ndim = max(nd1, nd2, nd3)
414+
st1 += [0] * (max_ndim - len(st1))
415+
st2 += [0] * (max_ndim - len(st2))
416+
st3 += [0] * (max_ndim - len(st3))
417+
perm = sorted(
418+
range(max_ndim),
419+
key=lambda d: (
420+
builtins.abs(st1[d]),
421+
builtins.abs(st2[d]),
422+
builtins.abs(st3[d]),
423+
),
424+
reverse=True,
425+
)
426+
inv_perm = sorted(range(max_ndim), key=lambda i: perm[i])
427+
st1_sorted = [st1[i] for i in perm]
428+
st2_sorted = [st2[i] for i in perm]
429+
st3_sorted = [st3[i] for i in perm]
430+
sh = res_shape
431+
sh_sorted = tuple(sh[i] for i in perm)
432+
R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C")
433+
if max(min(st1_sorted), min(st2_sorted), min(st3_sorted)) < 0:
434+
sl = tuple(
435+
slice(None, None, -1)
436+
if (st1_sorted[i] < 0 and st2_sorted[i] < 0 and st3_sorted[i] < 0)
437+
else slice(None, None, None)
438+
for i in range(nd1)
439+
)
440+
R = R[sl]
441+
return dpt.permute_dims(R, inv_perm)
442+
443+
383444
def copy(usm_ary, order="K"):
384445
"""copy(ary, order="K")
385446

dpctl/tensor/_search_functions.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import dpctl.tensor._tensor_impl as ti
2020
from dpctl.tensor._manipulation_functions import _broadcast_shapes
2121

22+
from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
2223
from ._type_utils import _all_data_types, _can_cast
2324

2425

@@ -121,7 +122,7 @@ def where(condition, x1, x2):
121122
deps = []
122123
wait_list = []
123124
if x1_dtype != dst_dtype:
124-
_x1 = dpt.empty_like(x1, dtype=dst_dtype)
125+
_x1 = _empty_like_orderK(x1, dst_dtype)
125126
ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
126127
src=x1, dst=_x1, sycl_queue=exec_q
127128
)
@@ -130,7 +131,7 @@ def where(condition, x1, x2):
130131
wait_list.append(ht_copy1_ev)
131132

132133
if x2_dtype != dst_dtype:
133-
_x2 = dpt.empty_like(x2, dtype=dst_dtype)
134+
_x2 = _empty_like_orderK(x2, dst_dtype)
134135
ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
135136
src=x2, dst=_x2, sycl_queue=exec_q
136137
)
@@ -142,25 +143,8 @@ def where(condition, x1, x2):
142143
x1 = dpt.broadcast_to(x1, res_shape)
143144
x2 = dpt.broadcast_to(x2, res_shape)
144145

145-
# dst is F-contiguous when all inputs are F contiguous
146-
# otherwise, defaults to C-contiguous
147-
if all(
148-
(
149-
condition.flags.fnc,
150-
x1.flags.fnc,
151-
x2.flags.fnc,
152-
)
153-
):
154-
order = "F"
155-
else:
156-
order = "C"
157-
158-
dst = dpt.empty(
159-
res_shape,
160-
dtype=dst_dtype,
161-
order=order,
162-
usm_type=dst_usm_type,
163-
sycl_queue=exec_q,
146+
dst = _empty_like_triple_orderK(
147+
condition, x1, x2, dst_dtype, res_shape, dst_usm_type, exec_q
164148
)
165149

166150
hev, _ = ti._where(

0 commit comments

Comments
 (0)