Skip to content

Cached device store #1076

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dpctl/_sycl_queue_manager.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
# distutils: language = c++
# cython: language_level=3

from ._sycl_device cimport SyclDevice
from ._sycl_queue cimport SyclQueue


cpdef SyclQueue get_current_queue()
cpdef get_current_device_type ()
cpdef get_current_backend()

cpdef object get_device_cached_queue(object)
45 changes: 45 additions & 0 deletions dpctl/_sycl_queue_manager.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import logging
from contextlib import ExitStack, contextmanager
from contextvars import ContextVar

from .enum_types import backend_type, device_type

Expand All @@ -35,6 +36,7 @@ from ._backend cimport ( # noqa: E211
_device_type,
)
from ._sycl_context cimport SyclContext
from ._sycl_device cimport SyclDevice

__all__ = [
"device_context",
Expand All @@ -44,6 +46,7 @@ __all__ = [
"get_num_activated_queues",
"is_in_device_context",
"set_global_queue",
"_global_device_queue_cache",
]

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -291,3 +294,45 @@ def device_context(arg):
_mgr._remove_current_queue()
else:
_logger.debug("No queue was created so nothing to do")


cdef class _DeviceDefaultQueueCache:
    """Cache of ``SyclQueue`` objects keyed by ``(SyclContext, SyclDevice)``.

    The cache is a plain dictionary; entries are added by
    :meth:`get_or_create` and are never evicted by this class.
    """
    cdef dict __device_queue_map__

    def __cinit__(self):
        self.__device_queue_map__ = dict()

    def get_or_create(self, key):
        """Return instance of SyclQueue and indicator if cache has been modified

        Args:
            key: Either a 2-tuple ``(SyclContext, SyclDevice)`` or a
                ``SyclDevice``.

        Returns:
            tuple: ``(queue, changed)`` where ``changed`` is ``True`` when
            a new entry was inserted into the cache by this call and
            ``False`` when an existing entry was returned.

        Raises:
            TypeError: If ``key`` is neither a ``SyclDevice`` nor a
                ``(SyclContext, SyclDevice)`` tuple.
        """
        if (
            isinstance(key, tuple)
            and len(key) == 2
            and isinstance(key[0], SyclContext)
            and isinstance(key[1], SyclDevice)
        ):
            ctx_dev = key
            # Defer queue construction until a cache miss is confirmed
            q = None
        elif isinstance(key, SyclDevice):
            # The context paired with a bare device is taken from a queue
            # constructed for that device, so the queue is built up front
            q = SyclQueue(key)
            ctx_dev = q.sycl_context, key
        else:
            raise TypeError(
                "Expected a SyclDevice or a (SyclContext, SyclDevice) "
                f"tuple as the cache key, got {type(key)}"
            )
        if ctx_dev in self.__device_queue_map__:
            return self.__device_queue_map__[ctx_dev], False
        if q is None:
            q = SyclQueue(*ctx_dev)
        self.__device_queue_map__[ctx_dev] = q
        return q, True

    cdef _update_map(self, dev_queue_map):
        """Merge the entries of ``dev_queue_map`` into this cache."""
        self.__device_queue_map__.update(dev_queue_map)

    def __copy__(self):
        """Return a new cache populated with the entries of this one.

        The mapping itself is duplicated; the cached queue objects are
        shared between the original and the copy.
        """
        cdef _DeviceDefaultQueueCache _copy = _DeviceDefaultQueueCache.__new__(_DeviceDefaultQueueCache)
        _copy._update_map(self.__device_queue_map__)
        return _copy


# Context variable holding the queue cache. The cache instance passed as
# ``default`` is constructed once, when this statement executes.
_global_device_queue_cache = ContextVar('global_device_queue_cache', default=_DeviceDefaultQueueCache())


cpdef object get_device_cached_queue(object key):
    """Return the cached queue for ``key``, creating it if needed.

    ``key`` is forwarded unchanged to ``get_or_create`` of the cache
    currently held by ``_global_device_queue_cache``. When that call
    reports that the cache was modified, the cache object is stored back
    into the context variable.
    """
    cache = _global_device_queue_cache.get()
    lookup = cache.get_or_create(key)
    queue, modified = lookup
    if modified:
        _global_device_queue_cache.set(cache)
    return queue
3 changes: 2 additions & 1 deletion dpctl/memory/_memory.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ from dpctl._backend cimport ( # noqa: E211
from .._sycl_context cimport SyclContext
from .._sycl_device cimport SyclDevice
from .._sycl_queue cimport SyclQueue
from .._sycl_queue_manager cimport get_device_cached_queue

import collections
import numbers
Expand Down Expand Up @@ -150,7 +151,7 @@ cdef class _Memory:

if (nbytes > 0):
if queue is None:
queue = dpctl.SyclQueue()
queue = get_device_cached_queue(dpctl.SyclDevice())

QRef = queue.get_queue_ref()
if (ptr_type == b"shared"):
Expand Down
9 changes: 3 additions & 6 deletions dpctl/tensor/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dpctl
from dpctl._sycl_queue_manager import get_device_cached_queue

__doc__ = "Implementation of array API mandated Device class"

Expand Down Expand Up @@ -60,9 +61,7 @@ def create_device(cls, dev):
elif isinstance(dev, dpctl.SyclDevice):
par = dev.parent_device
if par is None:
if dev not in cls.__device_queue_map__:
cls.__device_queue_map__[dev] = dpctl.SyclQueue(dev)
obj.sycl_queue_ = cls.__device_queue_map__[dev]
obj.sycl_queue_ = get_device_cached_queue(dev)
else:
raise ValueError(
f"Using non-root device {dev} to specify offloading "
Expand All @@ -74,9 +73,7 @@ def create_device(cls, dev):
_dev = dpctl.SyclDevice()
else:
_dev = dpctl.SyclDevice(dev)
if _dev not in cls.__device_queue_map__:
cls.__device_queue_map__[_dev] = dpctl.SyclQueue(_dev)
obj.sycl_queue_ = cls.__device_queue_map__[_dev]
obj.sycl_queue_ = get_device_cached_queue(_dev)
return obj

@property
Expand Down
9 changes: 5 additions & 4 deletions dpctl/tensor/_dlpack.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ from libc.stdint cimport int32_t, int64_t, uint8_t, uint16_t, uint64_t

cimport dpctl as c_dpctl
cimport dpctl.memory as c_dpmem
from dpctl._sycl_queue_manager cimport get_device_cached_queue

from .._backend cimport (
DPCTLDevice_Delete,
Expand Down Expand Up @@ -344,12 +345,12 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
if _IS_LINUX:
default_context = root_device.sycl_platform.default_context
else:
default_context = dpctl.SyclQueue(root_device).sycl_context
default_context = get_device_cached_queue(root_device).sycl_context
except RuntimeError:
default_context = dpctl.SyclQueue(root_device).sycl_context
default_context = get_device_cached_queue(root_device).sycl_context
if dlm_tensor.dl_tensor.data is NULL:
usm_type = b"device"
q = dpctl.SyclQueue(default_context, root_device)
q = get_device_cached_queue((default_context, root_device,))
else:
usm_type = c_dpmem._Memory.get_pointer_type(
<DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
Expand All @@ -364,7 +365,7 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
<DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
<c_dpctl.SyclContext>default_context
)
q = dpctl.SyclQueue(default_context, alloc_device)
q = get_device_cached_queue((default_context, alloc_device,))
if dlm_tensor.dl_tensor.dtype.bits % 8:
raise BufferError(
"Can not import DLPack tensor whose element's "
Expand Down
19 changes: 19 additions & 0 deletions dpctl/tests/test_sycl_queue_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,22 @@ def test_nested_context_factory_exception_if_wrong_factory(
with _register_nested_context_factory(factory):
with dpctl.device_context("opencl:cpu:0"):
pass


def test__DeviceDefaultQueueCache():
    """A copy of the global cache returns the already-cached queue.

    After a queue for the default device has been obtained through
    ``get_device_cached_queue``, a copy of the cache must return an equal
    queue for the same device and report that nothing was inserted.
    """
    import copy

    from dpctl._sycl_queue_manager import (
        _global_device_queue_cache as cache,
    )
    from dpctl._sycl_queue_manager import get_device_cached_queue

    try:
        default_dev = dpctl.SyclDevice()
    except dpctl.SyclDeviceCreationError:
        pytest.skip("Could not create default device")

    cached_queue = get_device_cached_queue(default_dev)
    copied_cache = copy.copy(cache.get())
    lookup = copied_cache.get_or_create(default_dev)
    copied_queue, was_modified = lookup

    assert not was_modified
    assert cached_queue == copied_queue
22 changes: 22 additions & 0 deletions dpctl/tests/test_usm_ndarray_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,28 @@ def test_dlpack_exporter(typestr, usm_type):
assert caps_fn(caps2, b"dltensor")


def test_dlpack_exporter_empty(typestr, usm_type):
    """Zero-size arrays export capsules that are valid under ``dltensor``.

    Checks a one-dimensional array of shape ``(0,)`` and a two-dimensional
    array of shape ``(1, 0)`` allocated on the default-selected device.
    """
    is_valid_capsule = ctypes.pythonapi.PyCapsule_IsValid
    is_valid_capsule.restype = bool
    is_valid_capsule.argtypes = [ctypes.py_object, ctypes.c_char_p]

    sycl_dev = dpctl.select_default_device()
    skip_if_dtype_not_supported(typestr, sycl_dev)

    empty_1d = dpt.empty(
        (0,), dtype=typestr, usm_type=usm_type, device=sycl_dev
    )
    capsule = empty_1d.__dlpack__()
    assert is_valid_capsule(capsule, b"dltensor")

    empty_2d = dpt.empty(
        (1, 0), dtype=typestr, usm_type=usm_type, device=sycl_dev
    )
    capsule = empty_2d.__dlpack__()
    assert is_valid_capsule(capsule, b"dltensor")


def test_dlpack_exporter_stream():
try:
q1 = dpctl.SyclQueue()
Expand Down