Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit cb41cf4

Browse files
committed
Default value to None
1 parent f7d6207 commit cb41cf4

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

numba/parfors/parfor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ def linspace_3(start, stop, num):
411411
else:
412412
raise ValueError("parallel linspace with types {}".format(args))
413413

414-
replace_functions_map = {
414+
swap_functions_map = {
415415
('argmin', 'numpy'): lambda r,a: argmin_parallel_impl,
416416
('argmax', 'numpy'): lambda r,a: argmax_parallel_impl,
417417
('min', 'numpy'): min_parallel_impl,
@@ -1387,14 +1387,16 @@ class PreParforPass(object):
13871387
implementations of numpy functions if available.
13881388
"""
13891389
def __init__(self, func_ir, typemap, calltypes, typingctx, options,
1390-
swapped={}, replace_functions_map=replace_functions_map):
1390+
swapped={}, replace_functions_map=None):
13911391
self.func_ir = func_ir
13921392
self.typemap = typemap
13931393
self.calltypes = calltypes
13941394
self.typingctx = typingctx
13951395
self.options = options
13961396
# diagnostics
13971397
self.swapped = swapped
1398+
if replace_functions_map is None:
1399+
replace_functions_map = swap_functions_map
13981400
self.replace_functions_map = replace_functions_map
13991401
self.stats = {
14001402
'replaced_func': 0,

numba/tests/test_parfors_passes.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(self, typingctx, targetctx, args, test_ir):
3838

3939
class BaseTest(TestCase):
4040
@classmethod
41-
def _run_parfor(cls, test_func, args):
41+
def _run_parfor(cls, test_func, args, swap_map=None):
4242
# TODO: refactor this with get_optimized_numba_ir() where this is
4343
# copied from
4444
typingctx = typing.Context()
@@ -80,6 +80,7 @@ def _run_parfor(cls, test_func, args):
8080
tp.state.typingctx,
8181
options,
8282
swapped=diagnostics.replaced_fns,
83+
replace_functions_map=swap_map,
8384
)
8485
preparfor_pass.run()
8586

@@ -109,9 +110,9 @@ def run_parfor_sub_pass(cls, test_func, args):
109110
return sub_pass
110111

111112
@classmethod
112-
def run_parfor_pre_pass(cls, test_func, args):
113+
def run_parfor_pre_pass(cls, test_func, args, swap_map=None):
113114
tp, options, diagnostics, preparfor_pass = cls._run_parfor(
114-
test_func, args
115+
test_func, args, swap_map
115116
)
116117
return preparfor_pass
117118

@@ -651,6 +652,20 @@ def test_impl(a):
651652
self.assertEqual(pre_pass.stats["replaced_dtype"], 0)
652653
self.run_parallel(test_impl, *args)
653654

655+
def test_replacement_map(self):
656+
def test_impl(a):
657+
return np.sum(a)
658+
659+
arr = np.arange(10)
660+
args = (arr,)
661+
argtypes = [typeof(x) for x in args]
662+
663+
swap_map = numba.parfors.parfor.swap_functions_map.copy()
664+
swap_map.pop(("sum", "numpy"))
665+
pre_pass = self.run_parfor_pre_pass(test_impl, argtypes, swap_map)
666+
self.assertEqual(pre_pass.stats["replaced_func"], 0)
667+
self.assertEqual(pre_pass.stats["replaced_dtype"], 0)
668+
self.run_parallel(test_impl, *args)
654669

655670
if __name__ == "__main__":
656671
unittest.main()

0 commit comments

Comments
 (0)