Skip to content

Commit a0bee90

Browse files
Closes gh-1350
dpctl.tensor.asarray implementation of order='K' processing was replaced with tested _empty_like_orderK utility to fix the issue reported in gh-1350. Few routines had to be shuffled to avoid import failure due to circular import dependencies.
1 parent 9afdb58 commit a0bee90

File tree

3 files changed

+65
-43
lines changed

3 files changed

+65
-43
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import dpctl.tensor as dpt
2525
import dpctl.tensor._tensor_impl as ti
2626
import dpctl.utils
27-
from dpctl.tensor._ctors import _get_dtype
27+
from dpctl.tensor._data_types import _get_dtype
2828
from dpctl.tensor._device import normalize_queue_device
2929

3030
__doc__ = (
@@ -354,11 +354,11 @@ def _empty_like_orderK(X, dt, usm_type=None, dev=None):
354354
range(X.ndim), key=lambda d: builtins.abs(st[d]), reverse=True
355355
)
356356
inv_perm = sorted(range(X.ndim), key=lambda i: perm[i])
357-
st_sorted = [st[i] for i in perm]
358357
sh = X.shape
359358
sh_sorted = tuple(sh[i] for i in perm)
360359
R = dpt.empty(sh_sorted, dtype=dt, usm_type=usm_type, device=dev, order="C")
361-
if min(st_sorted) < 0:
360+
if min(st) < 0:
361+
st_sorted = [st[i] for i in perm]
362362
sl = tuple(
363363
slice(None, None, -1)
364364
if st_sorted[i] < 0

dpctl/tensor/_ctors.py

Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import dpctl.tensor as dpt
2424
import dpctl.tensor._tensor_impl as ti
2525
import dpctl.utils
26+
from dpctl.tensor._copy_utils import _empty_like_orderK
27+
from dpctl.tensor._data_types import _get_dtype
2628
from dpctl.tensor._device import normalize_queue_device
2729
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol
2830

@@ -32,24 +34,6 @@
3234
_host_set = frozenset([None])
3335

3436

35-
def _get_dtype(dtype, sycl_obj, ref_type=None):
36-
if dtype is None:
37-
if ref_type in [None, float] or np.issubdtype(ref_type, np.floating):
38-
dtype = ti.default_device_fp_type(sycl_obj)
39-
return dpt.dtype(dtype)
40-
if ref_type in [bool, np.bool_]:
41-
dtype = ti.default_device_bool_type(sycl_obj)
42-
return dpt.dtype(dtype)
43-
if ref_type is int or np.issubdtype(ref_type, np.integer):
44-
dtype = ti.default_device_int_type(sycl_obj)
45-
return dpt.dtype(dtype)
46-
if ref_type is complex or np.issubdtype(ref_type, np.complexfloating):
47-
dtype = ti.default_device_complex_type(sycl_obj)
48-
return dpt.dtype(dtype)
49-
raise TypeError(f"Reference type {ref_type} not recognized.")
50-
return dpt.dtype(dtype)
51-
52-
5337
def _array_info_dispatch(obj):
5438
if isinstance(obj, dpt.usm_ndarray):
5539
return obj.shape, obj.dtype, frozenset([obj.sycl_queue])
@@ -106,6 +90,21 @@ def _array_info_sequence(li):
10690
return (n,) + dim, dt, device
10791

10892

93+
def _c_contig_strides(sh):
94+
if not sh:
95+
return tuple()
96+
el = 1
97+
n = len(sh)
98+
res = [
99+
el,
100+
] * n
101+
for i in range(n - 1, 0, -1):
102+
el *= sh[i]
103+
res[i - 1] = el
104+
105+
return tuple(res)
106+
107+
109108
def _asarray_from_usm_ndarray(
110109
usm_ndary,
111110
dtype=None,
@@ -162,28 +161,7 @@ def _asarray_from_usm_ndarray(
162161
order = "C" if c_contig else "F"
163162
if order == "K":
164163
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
165-
# new USM allocation
166-
res = dpt.usm_ndarray(
167-
usm_ndary.shape,
168-
dtype=dtype,
169-
buffer=usm_type,
170-
order="C",
171-
buffer_ctor_kwargs={"queue": copy_q},
172-
)
173-
original_strides = usm_ndary.strides
174-
ind = sorted(
175-
range(usm_ndary.ndim),
176-
key=lambda i: abs(original_strides[i]),
177-
reverse=True,
178-
)
179-
new_strides = tuple(res.strides[ind[i]] for i in ind)
180-
# reuse previously made USM allocation
181-
res = dpt.usm_ndarray(
182-
usm_ndary.shape,
183-
dtype=res.dtype,
184-
buffer=res.usm_data,
185-
strides=new_strides,
186-
)
164+
res = _empty_like_orderK(usm_ndary, dtype, usm_type, copy_q)
187165
else:
188166
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
189167
res = dpt.usm_ndarray(

dpctl/tensor/_data_types.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,25 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
from numpy import bool_ as np_bool_
18+
from numpy import complexfloating as np_complexfloating
1719
from numpy import dtype
20+
from numpy import floating as np_floating
21+
from numpy import integer as np_integer
22+
from numpy import issubdtype as np_issubdtype
23+
24+
from dpctl.tensor._tensor_impl import (
25+
default_device_bool_type as ti_default_device_bool_type,
26+
)
27+
from dpctl.tensor._tensor_impl import (
28+
default_device_complex_type as ti_default_device_complex_type,
29+
)
30+
from dpctl.tensor._tensor_impl import (
31+
default_device_fp_type as ti_default_device_fp_type,
32+
)
33+
from dpctl.tensor._tensor_impl import (
34+
default_device_int_type as ti_default_device_int_type,
35+
)
1836

1937
bool = dtype("bool")
2038
int8 = dtype("int8")
@@ -74,6 +92,32 @@ def isdtype(dtype_, kind):
7492
raise TypeError(f"Unsupported data type kind: {kind}")
7593

7694

95+
def _get_dtype(inp_dt, sycl_obj, ref_type=None):
96+
"""
97+
Type inference utility to construct data type
98+
object with defaults based on reference type.
99+
100+
_get_dtype is used by dpctl.tensor.asarray
101+
to infer data type of the output array from the
102+
input sequence.
103+
"""
104+
if inp_dt is None:
105+
if ref_type in [None, float] or np_issubdtype(ref_type, np_floating):
106+
fp_dt = ti_default_device_fp_type(sycl_obj)
107+
return dtype(fp_dt)
108+
if ref_type in [bool, np_bool_]:
109+
bool_dt = ti_default_device_bool_type(sycl_obj)
110+
return dtype(bool_dt)
111+
if ref_type is int or np_issubdtype(ref_type, np_integer):
112+
int_dt = ti_default_device_int_type(sycl_obj)
113+
return dtype(int_dt)
114+
if ref_type is complex or np_issubdtype(ref_type, np_complexfloating):
115+
cfp_dt = ti_default_device_complex_type(sycl_obj)
116+
return dtype(cfp_dt)
117+
raise TypeError(f"Reference type {ref_type} not recognized.")
118+
return dtype(inp_dt)
119+
120+
77121
__all__ = [
78122
"dtype",
79123
"isdtype",

0 commit comments

Comments
 (0)