Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Patch for with context #96

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 17 additions & 34 deletions numba/core/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def bar(x, y):
target = options.pop('target')
warnings.warn("The 'target' keyword argument is deprecated.", NumbaDeprecationWarning)
else:
target = options.pop('_target', None)
target = options.pop('_target', 'cpu')

options['boundscheck'] = boundscheck

Expand Down Expand Up @@ -183,16 +183,27 @@ def bar(x, y):


def _jit(sigs, locals, target, cache, targetoptions, **dispatcher_args):
dispatcher = registry.dispatcher_registry[target]

def wrapper(func):
if extending.is_jitted(func):
raise TypeError(
"A jit decorator was called on an already jitted function "
f"{func}. If trying to access the original python "
f"function, use the {func}.py_func attribute."
)

if not inspect.isfunction(func):
raise TypeError(
"The decorated object is not a function (got type "
f"{type(func)})."
)

def wrapper(func, dispatcher):
if config.ENABLE_CUDASIM and target == 'cuda':
from numba import cuda
return cuda.jit(func)
if config.DISABLE_JIT and not target == 'npyufunc':
return func
if target == 'dppl':
from . import dppl
return dppl.jit(func)
disp = dispatcher(py_func=func, locals=locals,
targetoptions=targetoptions,
**dispatcher_args)
Expand All @@ -208,35 +219,7 @@ def wrapper(func, dispatcher):
disp.disable_compile()
return disp

def __wrapper(func):
if extending.is_jitted(func):
raise TypeError(
"A jit decorator was called on an already jitted function "
f"{func}. If trying to access the original python "
f"function, use the {func}.py_func attribute."
)

if not inspect.isfunction(func):
raise TypeError(
"The decorated object is not a function (got type "
f"{type(func)})."
)

from numba import dppl_config
if (target == 'npyufunc' or targetoptions.get('no_cpython_wrapper')
or sigs or config.DISABLE_JIT or not targetoptions.get('nopython')
or dppl_config.dppl_present is not True):
target_ = target
if target_ is None:
target_ = 'cpu'
disp = registry.dispatcher_registry[target_]
return wrapper(func, disp)

from numba.dppl.target_dispatcher import TargetDispatcher
disp = TargetDispatcher(func, wrapper, target, targetoptions.get('parallel'))
return disp

return __wrapper
return wrapper


def generated_jit(function=None, target='cpu', cache=False,
Expand Down
12 changes: 1 addition & 11 deletions numba/core/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,14 +673,7 @@ def _set_uuid(self, u):
self._recent.append(self)


import abc

class DispatcherMeta(abc.ABCMeta):
def __instancecheck__(self, other):
return type(type(other)) == DispatcherMeta


class Dispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase, metaclass=DispatcherMeta):
class Dispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
"""
Implementation of user-facing dispatcher objects (i.e. created using
the @jit decorator).
Expand Down Expand Up @@ -906,9 +899,6 @@ def get_function_type(self):
cres = tuple(self.overloads.values())[0]
return types.FunctionType(cres.signature)

def get_compiled(self):
return self


class LiftedCode(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
"""
Expand Down
6 changes: 0 additions & 6 deletions numba/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from numba.core.descriptors import TargetDescriptor
from numba.core import utils, typing, dispatcher, cpu
from numba.core.compiler_lock import global_compiler_lock

# -----------------------------------------------------------------------------
# Default CPU target descriptors
Expand All @@ -27,19 +26,16 @@ class CPUTarget(TargetDescriptor):
_nested = _NestedContext()

@utils.cached_property
@global_compiler_lock
def _toplevel_target_context(self):
# Lazily-initialized top-level target context, for all threads
return cpu.CPUContext(self.typing_context)

@utils.cached_property
@global_compiler_lock
def _toplevel_typing_context(self):
# Lazily-initialized top-level typing context, for all threads
return typing.Context()

@property
@global_compiler_lock
def target_context(self):
"""
The target context for CPU targets.
Expand All @@ -51,7 +47,6 @@ def target_context(self):
return self._toplevel_target_context

@property
@global_compiler_lock
def typing_context(self):
"""
The typing context for CPU targets.
Expand All @@ -62,7 +57,6 @@ def typing_context(self):
else:
return self._toplevel_typing_context

@global_compiler_lock
def nested_context(self, typing_context, target_context):
"""
A context manager temporarily replacing the contexts with the
Expand Down
2 changes: 0 additions & 2 deletions numba/tests/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,6 @@ def test_serialization(self):
def foo(x):
return x + 1

foo = foo.get_compiled()

self.assertEqual(foo(1), 2)

# get serialization memo
Expand Down
2 changes: 0 additions & 2 deletions numba/tests/test_nrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,6 @@ def alloc_nrt_memory():
"""
return np.empty(N, dtype)

alloc_nrt_memory = alloc_nrt_memory.get_compiled()

def keep_memory():
return alloc_nrt_memory()

Expand Down
4 changes: 2 additions & 2 deletions numba/tests/test_record_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,8 +803,8 @@ def test_record_arg_transform(self):
self.assertIn('Array', transformed)
self.assertNotIn('first', transformed)
self.assertNotIn('second', transformed)
# Length is usually 60 - 5 chars tolerance as above.
self.assertLess(len(transformed), 60)
# Length is usually 50 - 5 chars tolerance as above.
self.assertLess(len(transformed), 50)

def test_record_two_arrays(self):
"""
Expand Down
6 changes: 3 additions & 3 deletions numba/tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def test_reuse(self):

Note that "same function" is intentionally under-specified.
"""
func = closure(5).get_compiled()
func = closure(5)
pickled = pickle.dumps(func)
func2 = closure(6).get_compiled()
func2 = closure(6)
pickled2 = pickle.dumps(func2)

f = pickle.loads(pickled)
Expand All @@ -152,7 +152,7 @@ def test_reuse(self):
self.assertEqual(h(2, 3), 11)

# Now make sure the original object doesn't exist when deserializing
func = closure(7).get_compiled()
func = closure(7)
func(42, 43)
pickled = pickle.dumps(func)
del func
Expand Down