Skip to content

Extending dpctl.device_context with nested contexts #678

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Nov 25, 2021
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


## [0.11.2] - 11/xx/2021

### Added
- Extending `dpctl.device_context` with nested contexts (#678)

## Fixed
- Fixed issue #649 about incorrect behavior of `.T` method on sliced arrays (#653)

Expand Down
2 changes: 2 additions & 0 deletions dpctl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
get_current_queue,
get_num_activated_queues,
is_in_device_context,
nested_context_factories,
set_global_queue,
)

Expand Down Expand Up @@ -111,6 +112,7 @@
"get_current_queue",
"get_num_activated_queues",
"is_in_device_context",
"nested_context_factories",
Copy link
Contributor

@oleksandr-pavlyk oleksandr-pavlyk Nov 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not see why this needs to be exposed, especially since device_context is on its way out.
If it were to stay we should need to design an API to manage these factories.

I.e. hold then in a dictionary and provide functions to retrieve the list of keys for the registered factories, to add a factory to this list with some key, to remove a factory by its key.

Even then, these function would belong to dpctl.utils namespace, rather than dpctl.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@PokhodenkoSA If we can avoid adding the function here, let us do it. It just adds more clean up for future.

@oleksandr-pavlyk Given that the changes need to make it to 0.12, I am willing to let it slide. Once, device_context is gone a whole lot of things will have to be deprecated and removed. Basically, all the functions (get_current_queue and family) above nested_context_factories.

Copy link
Contributor Author

@PokhodenkoSA PokhodenkoSA Nov 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I.e. hold then in a dictionary and provide functions to retrieve the list of keys for the registered factories, to add a factory to this list with some key, to remove a factory by its key.

The factory is a key. So we can use list. Also I think we should not guaranty order of contexts.
Agree about functions. But I do not want to make API, especially when users will only append nested context once and will not remove it at all - that is the main use case.

belong to dpctl.utils namespace, rather than dpctl

Partially agree that this is utils. But I think this should be close to device_context namespace. Ideally: dpctl.device_context.add_nested_context_factory(factory).

Once, device_context is gone a whole lot of things will have to be deprecated and removed.

Separate discussion needed: How to make numba redirect context active for compute follows data?

Basically, all the functions (get_current_queue and family) above nested_context_factories.

Agree. As they are located close to each other, so clean up will be easy and in the same time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we can avoid adding the function here, let us do it.

I think list is simple enough. @oleksandr-pavlyk let me keep things as is. I know that API should be good isolation. But in this case I would like to be simple.

"set_global_queue",
]
__all__ += [
Expand Down
38 changes: 36 additions & 2 deletions dpctl/_sycl_queue_manager.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# cython: linetrace=True

import logging
from contextlib import contextmanager
from contextlib import ExitStack, contextmanager

from .enum_types import backend_type, device_type

Expand Down Expand Up @@ -210,6 +210,22 @@ cpdef get_current_backend():
return _mgr.get_current_backend()


nested_context_factories = []


def _get_nested_contexts(ctxt):
_help_numba_dppy()
return (factory(ctxt) for factory in nested_context_factories)


def _help_numba_dppy():
"""Import numba-dppy for registering nested contexts"""
try:
import numba_dppy
except Exception:
pass


@contextmanager
def device_context(arg):
"""
Expand All @@ -222,6 +238,9 @@ def device_context(arg):
the context manager's scope. The yielded queue is removed as the currently
usable queue on exiting the context manager.

You can register context factory in the list of factories.
This context manager uses context factories to create and activate nested contexts.

Args:

queue_str (str) : A string corresponding to the DPC++ filter selector.
Expand All @@ -243,11 +262,26 @@ def device_context(arg):
with dpctl.device_context("level0:gpu:0"):
pass

The following example registers nested context factory:

.. code-block:: python

import dctl

def factory(sycl_queue):
...
return context

dpctl.nested_context_factories.append(factory)

"""
ctxt = None
try:
ctxt = _mgr._set_as_current_queue(arg)
yield ctxt
with ExitStack() as stack:
for nested_context in _get_nested_contexts(ctxt):
stack.enter_context(nested_context)
yield ctxt
finally:
# Code to release resource
if ctxt:
Expand Down
72 changes: 72 additions & 0 deletions dpctl/tests/test_sycl_queue_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"""Defines unit test cases for the SyclQueueManager class.
"""

import contextlib

import pytest

import dpctl
Expand Down Expand Up @@ -156,3 +158,73 @@ def test_get_current_backend():
dpctl.set_global_queue("gpu")
elif has_cpu():
dpctl.set_global_queue("cpu")


def test_nested_context_factory_is_empty_list():
assert isinstance(dpctl.nested_context_factories, list)
assert not dpctl.nested_context_factories


@contextlib.contextmanager
def _register_nested_context_factory(factory):
dpctl.nested_context_factories.append(factory)
try:
yield
finally:
dpctl.nested_context_factories.remove(factory)


def test_register_nested_context_factory_context():
def factory():
pass

with _register_nested_context_factory(factory):
assert factory in dpctl.nested_context_factories

assert isinstance(dpctl.nested_context_factories, list)
assert not dpctl.nested_context_factories


@pytest.mark.skipif(not has_cpu(), reason="No OpenCL CPU queues available")
def test_device_context_activates_nested_context():
in_context = False
factory_called = False

@contextlib.contextmanager
def context():
nonlocal in_context
old, in_context = in_context, True
yield
in_context = old

def factory(_):
nonlocal factory_called
factory_called = True
return context()

with _register_nested_context_factory(factory):
assert not factory_called
assert not in_context

with dpctl.device_context("opencl:cpu:0"):
assert factory_called
assert in_context

assert not in_context


@pytest.mark.skipif(not has_cpu(), reason="No OpenCL CPU queues available")
@pytest.mark.parametrize(
"factory, exception, match",
[
(True, TypeError, "object is not callable"),
(lambda x: None, AttributeError, "no attribute '__exit__'"),
],
)
def test_nested_context_factory_exception_if_wrong_factory(
factory, exception, match
):
with pytest.raises(exception, match=match):
with _register_nested_context_factory(factory):
with dpctl.device_context("opencl:cpu:0"):
pass