Skip to content

Commit 3ddf51c

Browse files
Corrected order='K' support in astype
Array API tests pointed out an error in the implementation of order='K' in dpctl.tensor.astype. Moved _empty_like_orderK and its companion _empty_like_pair_orderK from _type_utils to _copy_utils, and used _empty_like_orderK to implement astype. Modified the import statement in _elementwise_common, where the _empty_like* functions are used.
1 parent 07faf2b commit 3ddf51c

File tree

4 files changed

+103
-117
lines changed

4 files changed

+103
-117
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 97 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16+
import builtins
1617
import operator
1718

1819
import numpy as np
@@ -361,6 +362,96 @@ def copy(usm_ary, order="K"):
361362
return R
362363

363364

365+
def _empty_like_orderK(X, dt, usm_type=None, dev=None):
366+
"""Returns empty array like `x`, using order='K'
367+
368+
For an array `x` that was obtained by permutation of a contiguous
369+
array the returned array will have the same shape and the same
370+
strides as `x`.
371+
"""
372+
if not isinstance(X, dpt.usm_ndarray):
373+
raise TypeError(f"Expected usm_ndarray, got {type(X)}")
374+
if usm_type is None:
375+
usm_type = X.usm_type
376+
if dev is None:
377+
dev = X.device
378+
fl = X.flags
379+
if fl["C"] or X.size <= 1:
380+
return dpt.empty_like(
381+
X, dtype=dt, usm_type=usm_type, device=dev, order="C"
382+
)
383+
elif fl["F"]:
384+
return dpt.empty_like(
385+
X, dtype=dt, usm_type=usm_type, device=dev, order="F"
386+
)
387+
st = list(X.strides)
388+
perm = sorted(
389+
range(X.ndim), key=lambda d: builtins.abs(st[d]), reverse=True
390+
)
391+
inv_perm = sorted(range(X.ndim), key=lambda i: perm[i])
392+
st_sorted = [st[i] for i in perm]
393+
sh = X.shape
394+
sh_sorted = tuple(sh[i] for i in perm)
395+
R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C")
396+
if min(st_sorted) < 0:
397+
sl = tuple(
398+
slice(None, None, -1)
399+
if st_sorted[i] < 0
400+
else slice(None, None, None)
401+
for i in range(X.ndim)
402+
)
403+
R = R[sl]
404+
return dpt.permute_dims(R, inv_perm)
405+
406+
407+
def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
408+
if not isinstance(X1, dpt.usm_ndarray):
409+
raise TypeError(f"Expected usm_ndarray, got {type(X1)}")
410+
if not isinstance(X2, dpt.usm_ndarray):
411+
raise TypeError(f"Expected usm_ndarray, got {type(X2)}")
412+
nd1 = X1.ndim
413+
nd2 = X2.ndim
414+
if nd1 > nd2 and X1.shape == res_shape:
415+
return _empty_like_orderK(X1, dt, usm_type, dev)
416+
elif nd1 < nd2 and X2.shape == res_shape:
417+
return _empty_like_orderK(X2, dt, usm_type, dev)
418+
fl1 = X1.flags
419+
fl2 = X2.flags
420+
if fl1["C"] or fl2["C"]:
421+
return dpt.empty(
422+
res_shape, dtype=dt, usm_type=usm_type, device=dev, order="C"
423+
)
424+
if fl1["F"] and fl2["F"]:
425+
return dpt.empty(
426+
res_shape, dtype=dt, usm_type=usm_type, device=dev, order="F"
427+
)
428+
st1 = list(X1.strides)
429+
st2 = list(X2.strides)
430+
max_ndim = max(nd1, nd2)
431+
st1 += [0] * (max_ndim - len(st1))
432+
st2 += [0] * (max_ndim - len(st2))
433+
perm = sorted(
434+
range(max_ndim),
435+
key=lambda d: (builtins.abs(st1[d]), builtins.abs(st2[d])),
436+
reverse=True,
437+
)
438+
inv_perm = sorted(range(max_ndim), key=lambda i: perm[i])
439+
st1_sorted = [st1[i] for i in perm]
440+
st2_sorted = [st2[i] for i in perm]
441+
sh = res_shape
442+
sh_sorted = tuple(sh[i] for i in perm)
443+
R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C")
444+
if max(min(st1_sorted), min(st2_sorted)) < 0:
445+
sl = tuple(
446+
slice(None, None, -1)
447+
if (st1_sorted[i] < 0 and st2_sorted[i] < 0)
448+
else slice(None, None, None)
449+
for i in range(nd1)
450+
)
451+
R = R[sl]
452+
return dpt.permute_dims(R, inv_perm)
453+
454+
364455
def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
365456
""" astype(array, new_dtype, order="K", casting="unsafe", \
366457
copy=True)
@@ -432,26 +523,15 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
432523
"Unrecognized value of the order keyword. "
433524
"Recognized values are 'A', 'C', 'F', or 'K'"
434525
)
435-
R = dpt.usm_ndarray(
436-
usm_ary.shape,
437-
dtype=target_dtype,
438-
buffer=usm_ary.usm_type,
439-
order=copy_order,
440-
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
441-
)
442-
if order == "K" and (not c_contig and not f_contig):
443-
original_strides = usm_ary.strides
444-
ind = sorted(
445-
range(usm_ary.ndim),
446-
key=lambda i: abs(original_strides[i]),
447-
reverse=True,
448-
)
449-
new_strides = tuple(R.strides[ind[i]] for i in ind)
526+
if order == "K":
527+
R = _empty_like_orderK(usm_ary, target_dtype)
528+
else:
450529
R = dpt.usm_ndarray(
451530
usm_ary.shape,
452531
dtype=target_dtype,
453-
buffer=R.usm_data,
454-
strides=new_strides,
532+
buffer=usm_ary.usm_type,
533+
order=copy_order,
534+
buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
455535
)
456536
_copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
457537
return R

dpctl/tensor/_elementwise_common.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,10 +26,9 @@
2626
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer
2727
from dpctl.utils import ExecutionPlacementError
2828

29+
from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
2930
from ._type_utils import (
3031
_acceptance_fn_default,
31-
_empty_like_orderK,
32-
_empty_like_pair_orderK,
3332
_find_buf_dtype,
3433
_find_buf_dtype2,
3534
_find_inplace_dtype,

dpctl/tensor/_type_utils.py

Lines changed: 0 additions & 94 deletions
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,6 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
import builtins
18-
1917
import dpctl.tensor as dpt
2018
import dpctl.tensor._tensor_impl as ti
2119

@@ -116,96 +114,6 @@ def _can_cast(from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool):
116114
return can_cast_v
117115

118116

119-
def _empty_like_orderK(X, dt, usm_type=None, dev=None):
120-
"""Returns empty array like `x`, using order='K'
121-
122-
For an array `x` that was obtained by permutation of a contiguous
123-
array the returned array will have the same shape and the same
124-
strides as `x`.
125-
"""
126-
if not isinstance(X, dpt.usm_ndarray):
127-
raise TypeError(f"Expected usm_ndarray, got {type(X)}")
128-
if usm_type is None:
129-
usm_type = X.usm_type
130-
if dev is None:
131-
dev = X.device
132-
fl = X.flags
133-
if fl["C"] or X.size <= 1:
134-
return dpt.empty_like(
135-
X, dtype=dt, usm_type=usm_type, device=dev, order="C"
136-
)
137-
elif fl["F"]:
138-
return dpt.empty_like(
139-
X, dtype=dt, usm_type=usm_type, device=dev, order="F"
140-
)
141-
st = list(X.strides)
142-
perm = sorted(
143-
range(X.ndim), key=lambda d: builtins.abs(st[d]), reverse=True
144-
)
145-
inv_perm = sorted(range(X.ndim), key=lambda i: perm[i])
146-
st_sorted = [st[i] for i in perm]
147-
sh = X.shape
148-
sh_sorted = tuple(sh[i] for i in perm)
149-
R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C")
150-
if min(st_sorted) < 0:
151-
sl = tuple(
152-
slice(None, None, -1)
153-
if st_sorted[i] < 0
154-
else slice(None, None, None)
155-
for i in range(X.ndim)
156-
)
157-
R = R[sl]
158-
return dpt.permute_dims(R, inv_perm)
159-
160-
161-
def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
162-
if not isinstance(X1, dpt.usm_ndarray):
163-
raise TypeError(f"Expected usm_ndarray, got {type(X1)}")
164-
if not isinstance(X2, dpt.usm_ndarray):
165-
raise TypeError(f"Expected usm_ndarray, got {type(X2)}")
166-
nd1 = X1.ndim
167-
nd2 = X2.ndim
168-
if nd1 > nd2 and X1.shape == res_shape:
169-
return _empty_like_orderK(X1, dt, usm_type, dev)
170-
elif nd1 < nd2 and X2.shape == res_shape:
171-
return _empty_like_orderK(X2, dt, usm_type, dev)
172-
fl1 = X1.flags
173-
fl2 = X2.flags
174-
if fl1["C"] or fl2["C"]:
175-
return dpt.empty(
176-
res_shape, dtype=dt, usm_type=usm_type, device=dev, order="C"
177-
)
178-
if fl1["F"] and fl2["F"]:
179-
return dpt.empty(
180-
res_shape, dtype=dt, usm_type=usm_type, device=dev, order="F"
181-
)
182-
st1 = list(X1.strides)
183-
st2 = list(X2.strides)
184-
max_ndim = max(nd1, nd2)
185-
st1 += [0] * (max_ndim - len(st1))
186-
st2 += [0] * (max_ndim - len(st2))
187-
perm = sorted(
188-
range(max_ndim),
189-
key=lambda d: (builtins.abs(st1[d]), builtins.abs(st2[d])),
190-
reverse=True,
191-
)
192-
inv_perm = sorted(range(max_ndim), key=lambda i: perm[i])
193-
st1_sorted = [st1[i] for i in perm]
194-
st2_sorted = [st2[i] for i in perm]
195-
sh = res_shape
196-
sh_sorted = tuple(sh[i] for i in perm)
197-
R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C")
198-
if max(min(st1_sorted), min(st2_sorted)) < 0:
199-
sl = tuple(
200-
slice(None, None, -1)
201-
if (st1_sorted[i] < 0 and st2_sorted[i] < 0)
202-
else slice(None, None, None)
203-
for i in range(nd1)
204-
)
205-
R = R[sl]
206-
return dpt.permute_dims(R, inv_perm)
207-
208-
209117
def _to_device_supported_dtype(dt, dev):
210118
has_fp16 = dev.has_aspect_fp16
211119
has_fp64 = dev.has_aspect_fp64
@@ -339,8 +247,6 @@ def _find_inplace_dtype(lhs_dtype, rhs_dtype, query_fn, sycl_dev):
339247
"_find_buf_dtype",
340248
"_find_buf_dtype2",
341249
"_find_inplace_dtype",
342-
"_empty_like_orderK",
343-
"_empty_like_pair_orderK",
344250
"_to_device_supported_dtype",
345251
"_acceptance_fn_default",
346252
"_acceptance_fn_divide",

dpctl/tests/elementwise/test_type_utils.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818

1919
import dpctl
2020
import dpctl.tensor as dpt
21+
import dpctl.tensor._copy_utils as cu
2122
import dpctl.tensor._type_utils as tu
2223

2324
from .utils import _all_dtypes, _map_to_device_dtype
@@ -73,15 +74,15 @@ def test_type_utils_empty_like_orderK():
7374
a = dpt.empty((10, 10), dtype=dpt.int32, order="F")
7475
except dpctl.SyclDeviceCreationError:
7576
pytest.skip("No SYCL devices available")
76-
X = tu._empty_like_orderK(a, dpt.int32, a.usm_type, a.device)
77+
X = cu._empty_like_orderK(a, dpt.int32, a.usm_type, a.device)
7778
assert X.flags["F"]
7879

7980

8081
def test_type_utils_empty_like_orderK_invalid_args():
8182
with pytest.raises(TypeError):
82-
tu._empty_like_orderK([1, 2, 3], dpt.int32, "device", None)
83+
cu._empty_like_orderK([1, 2, 3], dpt.int32, "device", None)
8384
with pytest.raises(TypeError):
84-
tu._empty_like_pair_orderK(
85+
cu._empty_like_pair_orderK(
8586
[1, 2, 3],
8687
(
8788
1,
@@ -98,7 +99,7 @@ def test_type_utils_empty_like_orderK_invalid_args():
9899
except dpctl.SyclDeviceCreationError:
99100
pytest.skip("No SYCL devices available")
100101
with pytest.raises(TypeError):
101-
tu._empty_like_pair_orderK(
102+
cu._empty_like_pair_orderK(
102103
a,
103104
(
104105
1,

0 commit comments

Comments
 (0)