Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit d294c80

Browse files
authored
Semantics "with context" (#57)
1 parent 02b504f commit d294c80

File tree

11 files changed

+321
-50
lines changed

11 files changed

+321
-50
lines changed

numba/core/decorators.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,7 @@ def bar(x, y):
148148
target = options.pop('target')
149149
warnings.warn("The 'target' keyword argument is deprecated.", NumbaDeprecationWarning)
150150
else:
151-
target = options.pop('_target', 'cpu')
152-
153-
parallel_option = options.get('parallel')
154-
if isinstance(parallel_option, dict) and parallel_option.get('offload') is True:
155-
from numba.dppl import dppl_offload_dispatcher
156-
target = '__dppl_offload_gpu__'
151+
target = options.pop('_target', None)
157152

158153
options['boundscheck'] = boundscheck
159154

@@ -187,22 +182,8 @@ def bar(x, y):
187182

188183

189184
def _jit(sigs, locals, target, cache, targetoptions, **dispatcher_args):
190-
dispatcher = registry.dispatcher_registry[target]
191-
192-
def wrapper(func):
193-
if extending.is_jitted(func):
194-
raise TypeError(
195-
"A jit decorator was called on an already jitted function "
196-
f"{func}. If trying to access the original python "
197-
f"function, use the {func}.py_func attribute."
198-
)
199-
200-
if not inspect.isfunction(func):
201-
raise TypeError(
202-
"The decorated object is not a function (got type "
203-
f"{type(func)})."
204-
)
205185

186+
def wrapper(func, dispatcher):
206187
if config.ENABLE_CUDASIM and target == 'cuda':
207188
from numba import cuda
208189
return cuda.jit(func)
@@ -226,7 +207,33 @@ def wrapper(func):
226207
disp.disable_compile()
227208
return disp
228209

229-
return wrapper
210+
def __wrapper(func):
211+
if extending.is_jitted(func):
212+
raise TypeError(
213+
"A jit decorator was called on an already jitted function "
214+
f"{func}. If trying to access the original python "
215+
f"function, use the {func}.py_func attribute."
216+
)
217+
218+
if not inspect.isfunction(func):
219+
raise TypeError(
220+
"The decorated object is not a function (got type "
221+
f"{type(func)})."
222+
)
223+
224+
if (target == 'npyufunc' or targetoptions.get('no_cpython_wrapper')
225+
or sigs or config.DISABLE_JIT or not targetoptions.get('nopython')):
226+
target_ = target
227+
if target_ is None:
228+
target_ = 'cpu'
229+
disp = registry.dispatcher_registry[target_]
230+
return wrapper(func, disp)
231+
232+
from numba.dppl.target_dispatcher import TargetDispatcher
233+
disp = TargetDispatcher(func, wrapper, target, targetoptions.get('parallel'))
234+
return disp
235+
236+
return __wrapper
230237

231238

232239
def generated_jit(function=None, target='cpu', cache=False,

numba/core/dispatcher.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,14 @@ def _set_uuid(self, u):
673673
self._recent.append(self)
674674

675675

676-
class Dispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
676+
import abc


class DispatcherMeta(abc.ABCMeta):
    """Metaclass making ``isinstance()`` succeed for any object whose class
    was itself built by this metaclass.

    NOTE(review): the check deliberately ignores which class ``isinstance``
    was called with -- any instance of any DispatcherMeta-built class
    matches any other such class.  Confirm this blanket duck-typing is the
    intended contract.
    """

    def __instancecheck__(self, other):
        # Use identity (`is`) rather than equality (`==`): the metaclass is
        # a singleton object, and `is` cannot be subverted by a custom
        # __eq__ the way `==` could.
        return type(type(other)) is DispatcherMeta
681+
682+
683+
class Dispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase, metaclass=DispatcherMeta):
677684
"""
678685
Implementation of user-facing dispatcher objects (i.e. created using
679686
the @jit decorator).

numba/core/registry.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from numba.core.descriptors import TargetDescriptor
44
from numba.core import utils, typing, dispatcher, cpu
5+
from numba.core.compiler_lock import global_compiler_lock
56

67
# -----------------------------------------------------------------------------
78
# Default CPU target descriptors
@@ -26,16 +27,19 @@ class CPUTarget(TargetDescriptor):
2627
_nested = _NestedContext()
2728

2829
    @utils.cached_property
    @global_compiler_lock
    def _toplevel_target_context(self):
        # Lazily-initialized top-level target context, for all threads
        #
        # The body runs under global_compiler_lock, so the CPUContext is
        # constructed while holding the lock.  NOTE(review): presumably
        # utils.cached_property stores the result after the first call so
        # later accesses skip both the lock and the construction -- confirm
        # against utils.cached_property's implementation.
        return cpu.CPUContext(self.typing_context)
3234

3335
    @utils.cached_property
    @global_compiler_lock
    def _toplevel_typing_context(self):
        # Lazily-initialized top-level typing context, for all threads
        #
        # Construction is serialized by global_compiler_lock to avoid
        # racing first-time initialization from multiple threads.
        # NOTE(review): assumes utils.cached_property caches the value
        # after the first access -- verify.
        return typing.Context()
3740

3841
@property
42+
@global_compiler_lock
3943
def target_context(self):
4044
"""
4145
The target context for CPU targets.
@@ -47,6 +51,7 @@ def target_context(self):
4751
return self._toplevel_target_context
4852

4953
@property
54+
@global_compiler_lock
5055
def typing_context(self):
5156
"""
5257
The typing context for CPU targets.
@@ -57,6 +62,7 @@ def typing_context(self):
5762
else:
5863
return self._toplevel_typing_context
5964

65+
@global_compiler_lock
6066
def nested_context(self, typing_context, target_context):
6167
"""
6268
A context manager temporarily replacing the contexts with the

numba/dppl/dppl_offload_dispatcher.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
1-
from numba.core import dispatcher, compiler
2-
from numba.core.registry import cpu_target, dispatcher_registry
3-
import numba.dppl_config as dppl_config
4-
5-
6-
class DpplOffloadDispatcher(dispatcher.Dispatcher):
7-
targetdescr = cpu_target
8-
9-
def __init__(self, py_func, locals={}, targetoptions={}, impl_kind='direct', pipeline_class=compiler.Compiler):
10-
if dppl_config.dppl_present:
11-
from numba.dppl.compiler import DPPLCompiler
12-
targetoptions['parallel'] = True
13-
dispatcher.Dispatcher.__init__(self, py_func, locals=locals,
14-
targetoptions=targetoptions, impl_kind=impl_kind, pipeline_class=DPPLCompiler)
15-
else:
16-
print("---------------------------------------------------------------------")
17-
print("WARNING : DPPL pipeline ignored. Ensure OpenCL drivers are installed.")
18-
print("---------------------------------------------------------------------")
19-
dispatcher.Dispatcher.__init__(self, py_func, locals=locals,
20-
targetoptions=targetoptions, impl_kind=impl_kind, pipeline_class=pipeline_class)
21-
22-
dispatcher_registry['__dppl_offload_gpu__'] = DpplOffloadDispatcher
1+
from numba.core import dispatcher, compiler
from numba.core.registry import cpu_target, dispatcher_registry
import numba.dppl_config as dppl_config


class DpplOffloadDispatcher(dispatcher.Dispatcher):
    """Dispatcher routing jitted functions through the DPPL offload
    compiler pipeline when the DPPL runtime is present, and falling back
    to the supplied pipeline (with a printed warning) when it is not.
    """

    targetdescr = cpu_target

    def __init__(self, py_func, locals=None, targetoptions=None,
                 impl_kind='direct', pipeline_class=compiler.Compiler):
        # Bug fix: the original signature used mutable defaults
        # (`locals={}, targetoptions={}`).  Because `targetoptions` is
        # mutated below (`targetoptions['parallel'] = True`), the shared
        # default dict was poisoned for every subsequent no-arg
        # instantiation.  None-sentinels restore per-call fresh dicts.
        if locals is None:
            locals = {}
        if targetoptions is None:
            targetoptions = {}
        if dppl_config.dppl_present:
            from numba.dppl.compiler import DPPLCompiler
            # DPPL offload is built on top of the parallel backend.
            targetoptions['parallel'] = True
            dispatcher.Dispatcher.__init__(
                self, py_func, locals=locals, targetoptions=targetoptions,
                impl_kind=impl_kind, pipeline_class=DPPLCompiler)
        else:
            print("---------------------------------------------------------------------")
            print("WARNING : DPPL pipeline ignored. Ensure OpenCL drivers are installed.")
            print("---------------------------------------------------------------------")
            dispatcher.Dispatcher.__init__(
                self, py_func, locals=locals, targetoptions=targetoptions,
                impl_kind=impl_kind, pipeline_class=pipeline_class)


# Both offload targets share the same dispatcher class; the actual device
# is selected later from the active dpctl context.
dispatcher_registry['__dppl_offload_gpu__'] = DpplOffloadDispatcher
dispatcher_registry['__dppl_offload_cpu__'] = DpplOffloadDispatcher
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import numpy as np
2+
from numba import dppl, njit, prange
3+
import dpctl
4+
5+
6+
@njit
def g(a):
    """Return ``a`` incremented by one (jit-compiled helper)."""
    incremented = a + 1
    return incremented
9+
10+
11+
@njit
def f(a, b, c, N):
    """Parallel element-wise kernel: ``a[i] = b[i] + g(c[i])`` for i in [0, N).

    Mutates ``a`` in place and returns nothing.  Each iteration writes
    only its own ``a[i]``, so the ``prange`` parallel loop carries no
    cross-iteration dependency.
    """
    for i in prange(N):
        a[i] = b[i] + g(c[i])
15+
16+
17+
def main():
    """Run the jitted kernel on a GPU queue if available, else a CPU
    queue, else report that no device was found."""
    N = 10
    a = np.ones(N)
    b = np.ones(N)
    c = np.ones(N)

    # Pick the device type once, then enter a single context.
    if dpctl.has_gpu_queues():
        device = dpctl.device_type.gpu
    elif dpctl.has_cpu_queues():
        device = dpctl.device_type.cpu
    else:
        print("No device found")
        return

    with dpctl.device_context(device):
        f(a, b, c, N)


if __name__ == '__main__':
    main()

numba/dppl/target_dispatcher.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from numba.core import registry, serialize, dispatcher
2+
from numba import types
3+
from numba.core.errors import UnsupportedError
4+
import dpctl
5+
import dpctl.ocldrv as ocldr
6+
from numba.core.compiler_lock import global_compiler_lock
7+
8+
9+
class TargetDispatcher(serialize.ReduceMixin, metaclass=dispatcher.DispatcherMeta):
    """Lazy, context-aware dispatcher wrapper.

    Defers choosing the concrete dispatcher class until call time, so a
    ``with dpctl.device_context(...)`` block (or the
    ``parallel={'offload': True}`` option) can redirect compilation to a
    DPPL offload target.  Compiled dispatchers are cached per dispatcher
    class in ``self.__compiled``.
    """

    __numba__ = 'py_func'

    # Registry keys for the DPPL offload targets.
    target_offload_gpu = '__dppl_offload_gpu__'
    target_offload_cpu = '__dppl_offload_cpu__'
    target_dppl = 'dppl'

    def __init__(self, py_func, wrapper, target, parallel_options, compiled=None):
        self.__py_func = py_func
        self.__target = target
        self.__wrapper = wrapper
        # Cache mapping dispatcher class -> compiled dispatcher instance.
        self.__compiled = compiled if compiled is not None else {}
        self.__parallel = parallel_options
        # Mirror the wrapped function's metadata so this wrapper stays
        # introspection-friendly (help(), repr, pickle-by-name).
        self.__doc__ = py_func.__doc__
        self.__name__ = py_func.__name__
        self.__module__ = py_func.__module__

    def __call__(self, *args, **kwargs):
        # Compile (or fetch from cache) for the current context, then call.
        return self.get_compiled()(*args, **kwargs)

    def __getattr__(self, name):
        # Delegate unknown attribute access to the compiled dispatcher.
        return getattr(self.get_compiled(), name)

    def __get__(self, obj, objtype=None):
        # Descriptor protocol: bind the compiled dispatcher as a method.
        return self.get_compiled().__get__(obj, objtype)

    def __repr__(self):
        return self.get_compiled().__repr__()

    @classmethod
    def _rebuild(cls, py_func, wrapper, target, parallel, compiled):
        # Reconstruction hook used by serialize.ReduceMixin on unpickle.
        return cls(py_func, wrapper, target, parallel, compiled)

    def get_compiled(self, target=None):
        """Return (compiling on first use) the dispatcher for the current
        context.

        NOTE(review): ``target`` is accepted and defaulted from
        ``self.__target`` but never consumed -- the dispatcher class comes
        from ``get_current_disp()``.  Kept for interface compatibility.
        """
        if target is None:
            target = self.__target

        disp = self.get_current_disp()
        # Idiom fix: `disp not in self.__compiled` replaces the original
        # `not disp in self.__compiled.keys()` (same semantics).
        if disp not in self.__compiled:
            # Double-checked locking: take the global compiler lock only
            # when the entry is missing, then re-test under the lock.
            with global_compiler_lock:
                if disp not in self.__compiled:
                    self.__compiled[disp] = self.__wrapper(self.__py_func, disp)

        return self.__compiled[disp]

    def __is_with_context_target(self, target):
        # Only the default (None) target and the 'dppl' target may be used
        # inside a dpctl device context.
        return target is None or target == TargetDispatcher.target_dppl

    def get_current_disp(self):
        """Select the dispatcher class for the active device context and
        parallel options.

        Raises ``UnsupportedError`` when an explicit target or a parallel
        option conflicts with 'with context' offload semantics.
        """
        target = self.__target
        parallel = self.__parallel
        offload = isinstance(parallel, dict) and parallel.get('offload') is True

        if dpctl.is_in_device_context() or offload:
            if not self.__is_with_context_target(target):
                raise UnsupportedError(f"Can't use 'with' context with explicitly specified target '{target}'")
            if parallel is False or (isinstance(parallel, dict) and parallel.get('offload') is False):
                raise UnsupportedError(f"Can't use 'with' context with parallel option '{parallel}'")

            # Imported for its side effect: registers the offload targets
            # in the dispatcher registry.
            from numba.dppl import dppl_offload_dispatcher

            if target is None:
                if dpctl.get_current_device_type() == dpctl.device_type.gpu:
                    return registry.dispatcher_registry[TargetDispatcher.target_offload_gpu]
                elif dpctl.get_current_device_type() == dpctl.device_type.cpu:
                    return registry.dispatcher_registry[TargetDispatcher.target_offload_cpu]
                else:
                    # Inside a device context but neither gpu nor cpu.
                    if dpctl.is_in_device_context():
                        raise UnsupportedError('Unknown dppl device type')
                    if offload:
                        # Offload requested outside any device context:
                        # fall back to whichever queue type is available.
                        if dpctl.has_gpu_queues():
                            return registry.dispatcher_registry[TargetDispatcher.target_offload_gpu]
                        elif dpctl.has_cpu_queues():
                            return registry.dispatcher_registry[TargetDispatcher.target_offload_cpu]

        if target is None:
            target = 'cpu'

        return registry.dispatcher_registry[target]

    def _reduce_states(self):
        # State dict consumed by serialize.ReduceMixin / _rebuild.
        return dict(
            py_func=self.__py_func,
            wrapper=self.__wrapper,
            target=self.__target,
            parallel=self.__parallel,
            compiled=self.__compiled
        )

0 commit comments

Comments
 (0)