Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Numba patch for lower_extensions #215

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions numba/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ def init(self):
import numba.typed.dictimpl
import numba.experimental.function_type

# Add lower_extension attribute
self.lower_extensions = {}
from numba.parfors.parfor_lowering import _lower_parfor_parallel
from numba.parfors.parfor import Parfor
# Specify how to lower Parfor nodes using the lower_extensions
self.lower_extensions[Parfor] = _lower_parfor_parallel

def load_additional_registries(self):
# Add target specific implementations
from numba.np import npyimpl
Expand Down
20 changes: 5 additions & 15 deletions numba/core/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,20 +267,9 @@ def debug_print(self, msg):
self.context.debug_print(self.builder, "DEBUGJIT: {0}".format(msg))


# Dictionary mapping instruction class to its lowering function.
lower_extensions = {}


class Lower(BaseLower):
GeneratorLower = generators.GeneratorLower

def __init__(self, context, library, fndesc, func_ir, metadata=None):
BaseLower.__init__(self, context, library, fndesc, func_ir, metadata)
from numba.parfors.parfor_lowering import _lower_parfor_parallel
from numba.parfors import parfor
if parfor.Parfor not in lower_extensions:
lower_extensions[parfor.Parfor] = [_lower_parfor_parallel]

def pre_block(self, block):
from numba.core.unsafe import eh

Expand Down Expand Up @@ -445,10 +434,11 @@ def lower_inst(self, inst):
self.lower_static_try_raise(inst)

else:
for _class, func in lower_extensions.items():
if isinstance(inst, _class):
func[-1](self, inst)
return
if hasattr(self.context, "lower_extensions"):
for _class, func in self.context.lower_extensions.items():
if isinstance(inst, _class):
func(self, inst)
return
raise NotImplementedError(type(inst))

def lower_setitem(self, target_var, index_var, value_var, signature):
Expand Down
3 changes: 0 additions & 3 deletions numba/parfors/parfor_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,9 +479,6 @@ def _lower_parfor_parallel(lowerer, parfor):
if config.DEBUG_ARRAY_OPT:
print("_lower_parfor_parallel done")

# A work-around to prevent circular imports
#lowering.lower_extensions[parfor.Parfor] = _lower_parfor_parallel


def _create_shape_signature(
get_shape_classes,
Expand Down