Skip to content

Commit 352d874

Browse files
Extending dpctl.device_context with nested contexts (#678)
* Enable device_context to support numba-dppy offloading
* Update dpctl/_sycl_queue_manager.pyx
* Add nested_context_factories registry
* Add tests for registering nested context factory
* Add docs for registering nested contexts
* Update CHANGELOG
* Use `[Unreleased]` in CHANGELOG
* More rainy day tests
* Use CPU for tests

Co-authored-by: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com>
1 parent ae76671 commit 352d874

File tree

4 files changed

+115
-2
lines changed

4 files changed

+115
-2
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,13 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

7+
## [Unreleased]
8+
79
## [0.11.2] - 11/xx/2021
810

11+
### Added
12+
- Extending `dpctl.device_context` with nested contexts (#678)
13+
914
### Fixed
1015
- Fixed issue #649 about incorrect behavior of `.T` method on sliced arrays (#653)
1116

dpctl/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
get_current_queue,
6262
get_num_activated_queues,
6363
is_in_device_context,
64+
nested_context_factories,
6465
set_global_queue,
6566
)
6667

@@ -111,6 +112,7 @@
111112
"get_current_queue",
112113
"get_num_activated_queues",
113114
"is_in_device_context",
115+
"nested_context_factories",
114116
"set_global_queue",
115117
]
116118
__all__ += [

dpctl/_sycl_queue_manager.pyx

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
# cython: linetrace=True
2020

2121
import logging
22-
from contextlib import contextmanager
22+
from contextlib import ExitStack, contextmanager
2323

2424
from .enum_types import backend_type, device_type
2525

@@ -210,6 +210,22 @@ cpdef get_current_backend():
210210
return _mgr.get_current_backend()
211211

212212

213+
nested_context_factories = []
214+
215+
216+
def _get_nested_contexts(ctxt):
217+
_help_numba_dppy()
218+
return (factory(ctxt) for factory in nested_context_factories)
219+
220+
221+
def _help_numba_dppy():
222+
"""Import numba-dppy for registering nested contexts"""
223+
try:
224+
import numba_dppy
225+
except Exception:
226+
pass
227+
228+
213229
@contextmanager
214230
def device_context(arg):
215231
"""
@@ -222,6 +238,9 @@ def device_context(arg):
222238
the context manager's scope. The yielded queue is removed as the currently
223239
usable queue on exiting the context manager.
224240
241+
You can register a nested context factory by appending it to the ``dpctl.nested_context_factories`` list.
242+
This context manager uses context factories to create and activate nested contexts.
243+
225244
Args:
226245
227246
queue_str (str) : A string corresponding to the DPC++ filter selector.
@@ -243,11 +262,26 @@ def device_context(arg):
243262
with dpctl.device_context("level0:gpu:0"):
244263
pass
245264
265+
The following example registers a nested context factory:
266+
267+
.. code-block:: python
268+
269+
import dpctl
270+
271+
def factory(sycl_queue):
272+
...
273+
return context
274+
275+
dpctl.nested_context_factories.append(factory)
276+
246277
"""
247278
ctxt = None
248279
try:
249280
ctxt = _mgr._set_as_current_queue(arg)
250-
yield ctxt
281+
with ExitStack() as stack:
282+
for nested_context in _get_nested_contexts(ctxt):
283+
stack.enter_context(nested_context)
284+
yield ctxt
251285
finally:
252286
# Code to release resource
253287
if ctxt:

dpctl/tests/test_sycl_queue_manager.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
"""Defines unit test cases for the SyclQueueManager class.
1818
"""
1919

20+
import contextlib
21+
2022
import pytest
2123

2224
import dpctl
@@ -156,3 +158,73 @@ def test_get_current_backend():
156158
dpctl.set_global_queue("gpu")
157159
elif has_cpu():
158160
dpctl.set_global_queue("cpu")
161+
162+
163+
def test_nested_context_factory_is_empty_list():
164+
assert isinstance(dpctl.nested_context_factories, list)
165+
assert not dpctl.nested_context_factories
166+
167+
168+
@contextlib.contextmanager
169+
def _register_nested_context_factory(factory):
170+
dpctl.nested_context_factories.append(factory)
171+
try:
172+
yield
173+
finally:
174+
dpctl.nested_context_factories.remove(factory)
175+
176+
177+
def test_register_nested_context_factory_context():
178+
def factory():
179+
pass
180+
181+
with _register_nested_context_factory(factory):
182+
assert factory in dpctl.nested_context_factories
183+
184+
assert isinstance(dpctl.nested_context_factories, list)
185+
assert not dpctl.nested_context_factories
186+
187+
188+
@pytest.mark.skipif(not has_cpu(), reason="No OpenCL CPU queues available")
189+
def test_device_context_activates_nested_context():
190+
in_context = False
191+
factory_called = False
192+
193+
@contextlib.contextmanager
194+
def context():
195+
nonlocal in_context
196+
old, in_context = in_context, True
197+
yield
198+
in_context = old
199+
200+
def factory(_):
201+
nonlocal factory_called
202+
factory_called = True
203+
return context()
204+
205+
with _register_nested_context_factory(factory):
206+
assert not factory_called
207+
assert not in_context
208+
209+
with dpctl.device_context("opencl:cpu:0"):
210+
assert factory_called
211+
assert in_context
212+
213+
assert not in_context
214+
215+
216+
@pytest.mark.skipif(not has_cpu(), reason="No OpenCL CPU queues available")
217+
@pytest.mark.parametrize(
218+
"factory, exception, match",
219+
[
220+
(True, TypeError, "object is not callable"),
221+
(lambda x: None, AttributeError, "no attribute '__exit__'"),
222+
],
223+
)
224+
def test_nested_context_factory_exception_if_wrong_factory(
225+
factory, exception, match
226+
):
227+
with pytest.raises(exception, match=match):
228+
with _register_nested_context_factory(factory):
229+
with dpctl.device_context("opencl:cpu:0"):
230+
pass

0 commit comments

Comments
 (0)