Skip to content

Commit 18bb612

Browse files
Merge pull request #1133 from vlad-perevezentsev/add_dpctl.tensor.isdtype
Implement dpctl.tensor.isdtype
2 parents 1a6bba0 + 36233fd commit 18bb612

File tree

3 files changed

+175
-0
lines changed

3 files changed

+175
-0
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
int16,
5252
int32,
5353
int64,
54+
isdtype,
5455
uint8,
5556
uint16,
5657
uint32,
@@ -125,6 +126,7 @@
125126
"tril",
126127
"triu",
127128
"dtype",
129+
"isdtype",
128130
"bool",
129131
"int8",
130132
"uint8",

dpctl/tensor/_data_types.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,52 @@
3131
complex64 = dtype("complex64")
3232
complex128 = dtype("complex128")
3333

34+
35+
def isdtype(dtype_, kind):
36+
"""isdtype(dtype, kind)
37+
38+
Returns a boolean indicating whether a provided `dtype` is
39+
of a specified data type `kind`.
40+
41+
See [array API](array_api) for more information.
42+
43+
[array_api]: https://data-apis.org/array-api/latest/
44+
"""
45+
46+
if not isinstance(dtype_, dtype):
47+
raise TypeError("Expected instance of `dpt.dtype`, got {dtype_}")
48+
49+
if isinstance(kind, dtype):
50+
return dtype_ == kind
51+
52+
elif isinstance(kind, str):
53+
if kind == "bool":
54+
return dtype_ == dtype("bool")
55+
elif kind == "signed integer":
56+
return dtype_.kind == "i"
57+
elif kind == "unsigned integer":
58+
return dtype_.kind == "u"
59+
elif kind == "integral":
60+
return dtype_.kind in "iu"
61+
elif kind == "real floating":
62+
return dtype_.kind == "f"
63+
elif kind == "complex floating":
64+
return dtype_.kind == "c"
65+
elif kind == "numeric":
66+
return dtype_.kind in "iufc"
67+
else:
68+
raise ValueError(f"Unrecognized data type kind: {kind}")
69+
70+
elif isinstance(kind, tuple):
71+
return any(isdtype(dtype_, k) for k in kind)
72+
73+
else:
74+
raise TypeError(f"Unsupported data type kind: {kind}")
75+
76+
3477
__all__ = [
3578
"dtype",
79+
"isdtype",
3680
"bool",
3781
"int8",
3882
"uint8",
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
18+
import pytest
19+
20+
import dpctl.tensor as dpt
21+
22+
list_dtypes = [
23+
"bool",
24+
"int8",
25+
"int16",
26+
"int32",
27+
"int64",
28+
"uint8",
29+
"uint16",
30+
"uint32",
31+
"uint64",
32+
"float16",
33+
"float32",
34+
"float64",
35+
"complex64",
36+
"complex128",
37+
]
38+
39+
40+
dtype_categories = {
41+
"bool": ["bool"],
42+
"signed integer": ["int8", "int16", "int32", "int64"],
43+
"unsigned integer": ["uint8", "uint16", "uint32", "uint64"],
44+
"integral": [
45+
"int8",
46+
"int16",
47+
"int32",
48+
"int64",
49+
"uint8",
50+
"uint16",
51+
"uint32",
52+
"uint64",
53+
],
54+
"real floating": ["float16", "float32", "float64"],
55+
"complex floating": ["complex64", "complex128"],
56+
"numeric": [d for d in list_dtypes if d != "bool"],
57+
}
58+
59+
60+
@pytest.mark.parametrize("kind_str", dtype_categories.keys())
61+
@pytest.mark.parametrize("dtype_str", list_dtypes)
62+
def test_isdtype_kind_str(dtype_str, kind_str):
63+
dt = dpt.dtype(dtype_str)
64+
is_in_kind = dpt.isdtype(dt, kind_str)
65+
expected = dtype_str in dtype_categories[kind_str]
66+
assert is_in_kind == expected
67+
68+
69+
@pytest.mark.parametrize("dtype_str", list_dtypes)
70+
def test_isdtype_kind_tuple(dtype_str):
71+
dt = dpt.dtype(dtype_str)
72+
if dtype_str.startswith("bool"):
73+
assert dpt.isdtype(dt, ("real floating", "bool"))
74+
assert not dpt.isdtype(
75+
dt, ("integral", "real floating", "complex floating")
76+
)
77+
elif dtype_str.startswith("int"):
78+
assert dpt.isdtype(dt, ("real floating", "signed integer"))
79+
assert not dpt.isdtype(
80+
dt, ("bool", "unsigned integer", "real floating")
81+
)
82+
elif dtype_str.startswith("uint"):
83+
assert dpt.isdtype(dt, ("bool", "unsigned integer"))
84+
assert not dpt.isdtype(dt, ("real floating", "complex floating"))
85+
elif dtype_str.startswith("float"):
86+
assert dpt.isdtype(dt, ("complex floating", "real floating"))
87+
assert not dpt.isdtype(dt, ("integral", "complex floating", "bool"))
88+
else:
89+
assert dpt.isdtype(dt, ("integral", "complex floating"))
90+
assert not dpt.isdtype(dt, ("bool", "integral", "real floating"))
91+
92+
93+
@pytest.mark.parametrize("dtype_str", list_dtypes)
94+
def test_isdtype_kind_tuple_dtypes(dtype_str):
95+
dt = dpt.dtype(dtype_str)
96+
if dtype_str.startswith("bool"):
97+
assert dpt.isdtype(dt, (dpt.int32, dpt.bool))
98+
assert not dpt.isdtype(dt, (dpt.int16, dpt.uint32, dpt.float64))
99+
100+
elif dtype_str.startswith("int"):
101+
assert dpt.isdtype(dt, (dpt.int8, dpt.int16, dpt.int32, dpt.int64))
102+
assert not dpt.isdtype(dt, (dpt.bool, dpt.float32, dpt.complex64))
103+
104+
elif dtype_str.startswith("uint"):
105+
assert dpt.isdtype(dt, (dpt.uint8, dpt.uint16, dpt.uint32, dpt.uint64))
106+
assert not dpt.isdtype(dt, (dpt.bool, dpt.int32, dpt.float32))
107+
108+
elif dtype_str.startswith("float"):
109+
assert dpt.isdtype(dt, (dpt.float16, dpt.float32, dpt.float64))
110+
assert not dpt.isdtype(dt, (dpt.bool, dpt.complex64, dpt.int8))
111+
112+
else:
113+
assert dpt.isdtype(dt, (dpt.complex64, dpt.complex128))
114+
assert not dpt.isdtype(dt, (dpt.bool, dpt.uint64, dpt.int8))
115+
116+
117+
@pytest.mark.parametrize(
118+
"kind",
119+
[
120+
[dpt.int32, dpt.bool],
121+
"f4",
122+
float,
123+
123,
124+
"complex",
125+
],
126+
)
127+
def test_isdtype_invalid_kind(kind):
128+
with pytest.raises((TypeError, ValueError)):
129+
dpt.isdtype(dpt.int32, kind)

0 commit comments

Comments
 (0)