Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit f7d6207

Browse files
committed
Numba patch for lower_extensions
1 parent 7d26980 commit f7d6207

File tree

4 files changed

+15
-18
lines changed

4 files changed

+15
-18
lines changed

numba/core/cpu.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ def init(self):
6262
import numba.typed.dictimpl
6363
import numba.experimental.function_type
6464

65+
# Add lower_extension attribute
66+
self.lower_extensions = {}
67+
6568
def load_additional_registries(self):
6669
# Add target specific implementations
6770
from numba.np import npyimpl

numba/core/lowering.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -267,20 +267,9 @@ def debug_print(self, msg):
267267
self.context.debug_print(self.builder, "DEBUGJIT: {0}".format(msg))
268268

269269

270-
# Dictionary mapping instruction class to its lowering function.
271-
lower_extensions = {}
272-
273-
274270
class Lower(BaseLower):
275271
GeneratorLower = generators.GeneratorLower
276272

277-
def __init__(self, context, library, fndesc, func_ir, metadata=None):
278-
BaseLower.__init__(self, context, library, fndesc, func_ir, metadata)
279-
from numba.parfors.parfor_lowering import _lower_parfor_parallel
280-
from numba.parfors import parfor
281-
if parfor.Parfor not in lower_extensions:
282-
lower_extensions[parfor.Parfor] = [_lower_parfor_parallel]
283-
284273
def pre_block(self, block):
285274
from numba.core.unsafe import eh
286275

@@ -445,10 +434,11 @@ def lower_inst(self, inst):
445434
self.lower_static_try_raise(inst)
446435

447436
else:
448-
for _class, func in lower_extensions.items():
449-
if isinstance(inst, _class):
450-
func[-1](self, inst)
451-
return
437+
if hasattr(self.context, "lower_extensions"):
438+
for _class, func in self.context.lower_extensions.items():
439+
if isinstance(inst, _class):
440+
func(self, inst)
441+
return
452442
raise NotImplementedError(type(inst))
453443

454444
def lower_setitem(self, target_var, index_var, value_var, signature):

numba/core/typed_passes.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,13 @@ def run_pass(self, state):
277277
"""
278278
Convert data-parallel computations into Parfor nodes
279279
"""
280+
# Register lowerer for Parfor Node
281+
from numba.parfors.parfor_lowering import _lower_parfor_parallel
282+
if hasattr(state.targetctx, "lower_extensions"):
283+
state.targetctx.lower_extensions[Parfor] = _lower_parfor_parallel
284+
else:
285+
raise AttributeError("target_context has no attribute 'lower_extensions'")
286+
280287
# Ensure we have an IR and type information.
281288
assert state.func_ir
282289
parfor_pass = _parfor_ParforPass(state.func_ir,

numba/parfors/parfor_lowering.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -479,9 +479,6 @@ def _lower_parfor_parallel(lowerer, parfor):
479479
if config.DEBUG_ARRAY_OPT:
480480
print("_lower_parfor_parallel done")
481481

482-
# A work-around to prevent circular imports
483-
#lowering.lower_extensions[parfor.Parfor] = _lower_parfor_parallel
484-
485482

486483
def _create_shape_signature(
487484
get_shape_classes,

0 commit comments

Comments
 (0)