Skip to content

Commit a3c95a5

Browse files
committed
Added utility functions, basic where tests
- Utility functions are for finding an output type for universal and binary functions when the device of allocation lacks fp16 or fp64
1 parent 7ae6a1a commit a3c95a5

File tree

3 files changed

+217
-1
lines changed

3 files changed

+217
-1
lines changed

dpctl/tensor/_search_functions.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,25 @@
1919
import dpctl.tensor._tensor_impl as ti
2020
from dpctl.tensor._manipulation_functions import _broadcast_shapes
2121

22+
from ._type_utils import _all_data_types, _can_cast
23+
24+
25+
def _where_result_type(dt1, dt2, dev):
26+
res_dtype = dpt.result_type(dt1, dt2)
27+
fp16 = dev.has_aspect_fp16
28+
fp64 = dev.has_aspect_fp64
29+
30+
all_dts = _all_data_types(fp16, fp64)
31+
if res_dtype in all_dts:
32+
return res_dtype
33+
else:
34+
for res_dtype_ in all_dts:
35+
if _can_cast(dt1, res_dtype_, fp16, fp64) and _can_cast(
36+
dt2, res_dtype_, fp16, fp64
37+
):
38+
return res_dtype_
39+
return None
40+
2241

2342
def where(condition, x1, x2):
2443
if not isinstance(condition, dpt.usm_ndarray):
@@ -52,7 +71,7 @@ def where(condition, x1, x2):
5271

5372
x1_dtype = x1.dtype
5473
x2_dtype = x2.dtype
55-
dst_dtype = dpt.result_type(x1.dtype, x2.dtype)
74+
dst_dtype = _where_result_type(x1_dtype, x2_dtype, exec_q.sycl_device)
5675

5776
if condition.size == 0:
5877
return dpt.asarray(

dpctl/tensor/_type_utils.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2022 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import dpctl.tensor as dpt
18+
19+
20+
def _all_data_types(_fp16, _fp64):
21+
if _fp64:
22+
if _fp16:
23+
return [
24+
dpt.bool,
25+
dpt.int8,
26+
dpt.uint8,
27+
dpt.int16,
28+
dpt.uint16,
29+
dpt.int32,
30+
dpt.uint32,
31+
dpt.int64,
32+
dpt.uint64,
33+
dpt.float16,
34+
dpt.float32,
35+
dpt.float64,
36+
dpt.complex64,
37+
dpt.complex128,
38+
]
39+
else:
40+
return [
41+
dpt.bool,
42+
dpt.int8,
43+
dpt.uint8,
44+
dpt.int16,
45+
dpt.uint16,
46+
dpt.int32,
47+
dpt.uint32,
48+
dpt.int64,
49+
dpt.uint64,
50+
dpt.float32,
51+
dpt.float64,
52+
dpt.complex64,
53+
dpt.complex128,
54+
]
55+
else:
56+
if _fp16:
57+
return [
58+
dpt.bool,
59+
dpt.int8,
60+
dpt.uint8,
61+
dpt.int16,
62+
dpt.uint16,
63+
dpt.int32,
64+
dpt.uint32,
65+
dpt.int64,
66+
dpt.uint64,
67+
dpt.float16,
68+
dpt.float32,
69+
dpt.complex64,
70+
]
71+
else:
72+
return [
73+
dpt.bool,
74+
dpt.int8,
75+
dpt.uint8,
76+
dpt.int16,
77+
dpt.uint16,
78+
dpt.int32,
79+
dpt.uint32,
80+
dpt.int64,
81+
dpt.uint64,
82+
dpt.float32,
83+
dpt.complex64,
84+
]
85+
86+
87+
def is_maximal_inexact_type(dt: dpt.dtype, _fp16: bool, _fp64: bool):
88+
"""
89+
Return True if data type `dt` is the
90+
maximal size inexact data type
91+
"""
92+
if _fp64:
93+
return dt in [dpt.float64, dpt.complex128]
94+
return dt in [dpt.float32, dpt.complex64]
95+
96+
97+
def _can_cast(from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool):
98+
"""
99+
Can `from_` be cast to `to_` safely on a device with
100+
fp16 and fp64 aspects as given?
101+
"""
102+
can_cast_v = dpt.can_cast(from_, to_) # ask NumPy
103+
if _fp16 and _fp64:
104+
return can_cast_v
105+
if not can_cast_v:
106+
if (
107+
from_.kind in "biu"
108+
and to_.kind in "fc"
109+
and is_maximal_inexact_type(to_, _fp16, _fp64)
110+
):
111+
return True
112+
113+
return can_cast_v
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2022 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import numpy as np
18+
import pytest
19+
from helper import get_queue_or_skip, skip_if_dtype_not_supported
20+
from numpy.testing import assert_array_equal
21+
22+
import dpctl.tensor as dpt
23+
24+
_all_dtypes = [
25+
"u1",
26+
"i1",
27+
"u2",
28+
"i2",
29+
"u4",
30+
"i4",
31+
"u8",
32+
"i8",
33+
"e",
34+
"f",
35+
"d",
36+
"F",
37+
"D",
38+
]
39+
40+
41+
def test_where_basic():
42+
get_queue_or_skip
43+
44+
cond = dpt.asarray(
45+
[
46+
[True, False, False],
47+
[False, True, False],
48+
[False, False, True],
49+
[False, False, False],
50+
[True, True, True],
51+
]
52+
)
53+
out = dpt.where(cond, dpt.asarray(1), dpt.asarray(0))
54+
out_expected = dpt.asarray(
55+
[[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0], [1, 1, 1]]
56+
)
57+
58+
assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all()
59+
60+
61+
@pytest.mark.parametrize("dt1", _all_dtypes)
62+
@pytest.mark.parametrize("dt2", _all_dtypes)
63+
def test_where_all_dtypes(dt1, dt2):
64+
q = get_queue_or_skip()
65+
skip_if_dtype_not_supported(dt1, q)
66+
skip_if_dtype_not_supported(dt2, q)
67+
68+
cond_np = np.arange(5) > 2
69+
x1_np = np.asarray(2, dtype=dt1)
70+
x2_np = np.asarray(3, dtype=dt2)
71+
72+
cond = dpt.asarray(cond_np, sycl_queue=q)
73+
x1 = dpt.asarray(x1_np, sycl_queue=q)
74+
x2 = dpt.asarray(x2_np, sycl_queue=q)
75+
76+
res = dpt.where(cond, x1, x2)
77+
res_np = np.where(cond_np, x1_np, x2_np)
78+
79+
if res.dtype != res_np.dtype:
80+
assert res.dtype.kind == res_np.dtype.kind
81+
assert_array_equal(dpt.asnumpy(res).astype(res_np.dtype), res_np)
82+
83+
else:
84+
assert_array_equal(dpt.asnumpy(res), res_np)

0 commit comments

Comments
 (0)