Skip to content

Commit e96cce3

Browse files
Merge pull request #797 from IntelPython/get-coerced-usm-type
2 parents c274840 + 99f017d commit e96cce3

File tree

5 files changed

+72
-22
lines changed

5 files changed

+72
-22
lines changed

dpctl/tensor/_device.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class Device:
2929
or ``sycl_device``.
3030
"""
3131

32+
__device_queue_map__ = dict()
33+
3234
def __new__(cls, *args, **kwargs):
3335
raise TypeError("No public constructor")
3436

@@ -55,7 +57,9 @@ def create_device(cls, dev):
5557
elif isinstance(dev, dpctl.SyclDevice):
5658
par = dev.parent_device
5759
if par is None:
58-
obj.sycl_queue_ = dpctl.SyclQueue(dev)
60+
if dev not in cls.__device_queue_map__:
61+
cls.__device_queue_map__[dev] = dpctl.SyclQueue(dev)
62+
obj.sycl_queue_ = cls.__device_queue_map__[dev]
5963
else:
6064
raise ValueError(
6165
"Using non-root device {} to specify offloading "
@@ -64,9 +68,12 @@ def create_device(cls, dev):
6468
)
6569
else:
6670
if dev is None:
67-
obj.sycl_queue_ = dpctl.SyclQueue()
71+
_dev = dpctl.SyclDevice()
6872
else:
69-
obj.sycl_queue_ = dpctl.SyclQueue(dev)
73+
_dev = dpctl.SyclDevice(dev)
74+
if _dev not in cls.__device_queue_map__:
75+
cls.__device_queue_map__[_dev] = dpctl.SyclQueue(_dev)
76+
obj.sycl_queue_ = cls.__device_queue_map__[_dev]
7077
return obj
7178

7279
@property

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,9 +320,11 @@ def test_datapi_device():
320320
dev_t()
321321
dev_t.create_device(X.device)
322322
dev_t.create_device(X.sycl_queue)
323-
dev_t.create_device(X.sycl_device)
324-
dev_t.create_device(X.sycl_device.filter_string)
325-
dev_t.create_device(None)
323+
d1 = dev_t.create_device(X.sycl_device)
324+
d2 = dev_t.create_device(X.sycl_device.filter_string)
325+
d3 = dev_t.create_device(None)
326+
assert d1.sycl_queue == d2.sycl_queue
327+
assert d1.sycl_queue == d3.sycl_queue
326328
X.device.sycl_context
327329
X.device.sycl_queue
328330
X.device.sycl_device

dpctl/tests/test_utils.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,17 @@ def test_get_execution_queue():
5454
q,
5555
)
5656
)
57-
assert exec_q is q
57+
assert exec_q is None
58+
q_c = dpctl.SyclQueue(q._get_capsule())
59+
assert q == q_c
60+
exec_q = dpctl.utils.get_execution_queue(
61+
(
62+
q,
63+
q_c,
64+
q,
65+
)
66+
)
67+
assert exec_q == q
5868

5969

6070
def test_get_execution_queue_nonequiv():
@@ -69,3 +79,18 @@ def test_get_execution_queue_nonequiv():
6979

7080
exec_q = dpctl.utils.get_execution_queue((q, q1, q2))
7181
assert exec_q is None
82+
83+
84+
def test_get_coerced_usm_type():
85+
_t = ["device", "shared", "host"]
86+
87+
for i1 in range(len(_t)):
88+
for i2 in range(len(_t)):
89+
assert (
90+
dpctl.utils.get_coerced_usm_type([_t[i1], _t[i2]])
91+
== _t[min(i1, i2)]
92+
)
93+
94+
assert dpctl.utils.get_coerced_usm_type([]) is None
95+
with pytest.raises(TypeError):
96+
dpctl.utils.get_coerced_usm_type(dict())

dpctl/utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
A collection of utility functions.
1919
"""
2020

21-
from ._compute_follows_data import get_execution_queue
21+
from ._compute_follows_data import get_coerced_usm_type, get_execution_queue
2222

2323
__all__ = [
2424
"get_execution_queue",
25+
"get_coerced_usm_type",
2526
]

dpctl/utils/_compute_follows_data.pyx

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,27 +28,19 @@ import dpctl
2828

2929
from .._sycl_queue cimport SyclQueue
3030

31-
__all__ = ["get_execution_queue", ]
31+
__all__ = ["get_execution_queue", "get_coerced_usm_type"]
3232

3333

3434
cdef bint queue_equiv(SyclQueue q1, SyclQueue q2):
35-
""" Queues are equivalent if contexts are the same,
36-
devices are the same, and properties are the same."""
37-
return (
38-
(q1 is q2) or
39-
(
40-
(q1.sycl_context == q2.sycl_context) and
41-
(q1.sycl_device == q2.sycl_device) and
42-
(q1.is_in_order == q2.is_in_order) and
43-
(q1.has_enable_profiling == q2.has_enable_profiling)
44-
)
45-
)
35+
""" Queues are equivalent if q1 == q2, that is they are copies
36+
of the same underlying SYCL object and hence are the same."""
37+
return q1.__eq__(q2)
4638

4739

4840
def get_execution_queue(qs):
4941
""" Given a list of :class:`dpctl.SyclQueue` objects
5042
returns the execution queue under compute follows data paradigm,
51-
or returns `None` if queues are not equivalent.
43+
or returns `None` if queues are not equal.
5244
"""
5345
if not isinstance(qs, (list, tuple)):
5446
raise TypeError(
@@ -58,11 +50,34 @@ def get_execution_queue(qs):
5850
return None
5951
elif len(qs) == 1:
6052
return qs[0] if isinstance(qs[0], dpctl.SyclQueue) else None
61-
for q1, q2 in zip(qs, qs[1:]):
53+
for q1, q2 in zip(qs[:-1], qs[1:]):
6254
if not isinstance(q1, dpctl.SyclQueue):
6355
return None
6456
elif not isinstance(q2, dpctl.SyclQueue):
6557
return None
6658
elif not queue_equiv(<SyclQueue> q1, <SyclQueue> q2):
6759
return None
6860
return qs[0]
61+
62+
63+
def get_coerced_usm_type(usm_types):
64+
""" Given a list of strings denoting the types of USM allocations
65+
for input arrays returns the type of USM allocation for the output
66+
array(s) per compute follows data paradigm.
67+
Returns `None` if the type can not be deduced."""
68+
if not isinstance(usm_types, (list, tuple)):
69+
raise TypeError(
70+
"Expected a list or a tuple, got {}".format(type(usm_types))
71+
)
72+
if len(usm_types) == 0:
73+
return None
74+
_k = ["device", "shared", "host"]
75+
_m = {k:i for i, k in enumerate(_k)}
76+
res = len(_k)
77+
for t in usm_types:
78+
if not isinstance(t, str):
79+
return None
80+
if t not in _m:
81+
return None
82+
res = min(res, _m[t])
83+
return _k[res]

0 commit comments

Comments
 (0)