Skip to content

Commit fb9c310

Browse files
(numba/dppl) Fix ir copy in fallback (#61)
1 parent 1492a13 commit fb9c310

File tree

3 files changed

+195
-58
lines changed

3 files changed

+195
-58
lines changed

dppl_lowerer.py

Lines changed: 105 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -699,21 +699,22 @@ def _lower_parfor_gufunc(lowerer, parfor):
699699
numba.parfors.parfor.sequential_parfor_lowering = True
700700
loop_ranges = [(l.start, l.stop, l.step) for l in parfor.loop_nests]
701701

702-
func, func_args, func_sig, func_arg_types, modified_arrays =(
703-
_create_gufunc_for_parfor_body(
704-
lowerer,
705-
parfor,
706-
typemap,
707-
typingctx,
708-
targetctx,
709-
flags,
710-
loop_ranges,
711-
{},
712-
bool(alias_map),
713-
index_var_typ,
714-
parfor.races))
715-
716-
numba.parfors.parfor.sequential_parfor_lowering = False
702+
try:
703+
func, func_args, func_sig, func_arg_types, modified_arrays =(
704+
_create_gufunc_for_parfor_body(
705+
lowerer,
706+
parfor,
707+
typemap,
708+
typingctx,
709+
targetctx,
710+
flags,
711+
loop_ranges,
712+
{},
713+
bool(alias_map),
714+
index_var_typ,
715+
parfor.races))
716+
finally:
717+
numba.parfors.parfor.sequential_parfor_lowering = False
717718

718719
# get the shape signature
719720
get_shape_classes = parfor.get_shape_classes
@@ -956,10 +957,14 @@ def load_range(v):
956957
from numba.core.lowering import Lower
957958

958959

960+
class CopyIRException(RuntimeError):
    """Raised when the parfor IR cannot be copied before device lowering.

    Behaves exactly like ``RuntimeError``; the dedicated type only lets
    callers distinguish an IR-copy failure from other lowering errors.
    """
963+
964+
959965
class DPPLLower(Lower):
960966
def __init__(self, context, library, fndesc, func_ir, metadata=None):
961967
Lower.__init__(self, context, library, fndesc, func_ir, metadata)
962-
963968
fndesc_cpu = copy.copy(fndesc)
964969
fndesc_cpu.calltypes = fndesc.calltypes.copy()
965970
fndesc_cpu.typemap = fndesc.typemap.copy()
@@ -984,17 +989,23 @@ def lower(self):
984989
# WARNING: this approach only works in case no device specific modifications were added to
985990
# parent function (function with parfor). In case the parent function was patched with device-specific changes, a
986991
# different solution should be used.
992+
987993
try:
988994
#lowering.lower_extensions[parfor.Parfor] = lower_parfor_rollback
989995
lowering.lower_extensions[parfor.Parfor].append(lower_parfor_rollback)
990996
self.gpu_lower.lower()
991997
self.base_lower = self.gpu_lower
992998
#lowering.lower_extensions[parfor.Parfor] = numba.parfors.parfor_lowering._lower_parfor_parallel
993999
lowering.lower_extensions[parfor.Parfor].pop()
994-
except:
995-
lowering.lower_extensions[parfor.Parfor].append(numba.parfors.parfor_lowering._lower_parfor_parallel)
996-
self.cpu_lower.lower()
997-
self.base_lower = self.cpu_lower
1000+
except Exception as e:
1001+
if numba.dppl.compiler.DEBUG:
1002+
print("Failed to lower parfor on DPPL-device. Due to:\n", e)
1003+
lowering.lower_extensions[parfor.Parfor].pop()
1004+
if (lowering.lower_extensions[parfor.Parfor][-1] == numba.parfors.parfor_lowering._lower_parfor_parallel):
1005+
self.cpu_lower.lower()
1006+
self.base_lower = self.cpu_lower
1007+
else:
1008+
raise e
9981009

9991010
self.env = self.base_lower.env
10001011
self.call_helper = self.base_lower.call_helper
@@ -1003,20 +1014,88 @@ def create_cpython_wrapper(self, release_gil=False):
10031014
return self.base_lower.create_cpython_wrapper(release_gil)
10041015

10051016

1017+
def copy_block(block):
    """Return a copy of an IR block whose statements are copied recursively.

    A new ``ir.Block`` is created with the same ``scope`` and ``loc`` as
    *block*; its body holds a "relatively deep" copy of every statement.
    Lists, dicts, tuples and sets are rebuilt element by element, other
    objects are shallow-copied and then have the attributes named by their
    ``__dict__`` (or, failing that, ``__slots__``) copied recursively.
    ``Dispatcher``, ``Function`` and module objects are shared by reference.
    """
    # Imported once per copy_block call instead of on every recursive step.
    from numba.core.dispatcher import Dispatcher
    from numba.core.types.functions import Function
    from types import ModuleType

    shared_types = (Dispatcher, Function, ModuleType)

    def relatively_deep_copy(obj, memo):
        obj_id = id(obj)
        if obj_id in memo:
            return memo[obj_id]

        if isinstance(obj, shared_types):
            return obj

        if isinstance(obj, list):
            cpy = copy.copy(obj)
            cpy.clear()
            # Register before recursing so self-referential containers
            # resolve to the copy instead of recursing without bound.
            memo[obj_id] = cpy
            for item in obj:
                cpy.append(relatively_deep_copy(item, memo))
            return cpy

        if isinstance(obj, dict):
            cpy = copy.copy(obj)
            cpy.clear()
            memo[obj_id] = cpy
            for key, item in obj.items():
                cpy[relatively_deep_copy(key, memo)] = relatively_deep_copy(item, memo)
            return cpy

        if isinstance(obj, tuple):
            items = [relatively_deep_copy(item, memo) for item in obj]
            # namedtuple classes take positional fields, not one iterable;
            # they expose ``_make`` for building from an iterable.
            make = getattr(type(obj), "_make", None)
            cpy = make(items) if make is not None else type(obj)(items)
            memo[obj_id] = cpy
            return cpy

        if isinstance(obj, set):
            cpy = copy.copy(obj)
            cpy.clear()
            memo[obj_id] = cpy
            for item in obj:
                cpy.add(relatively_deep_copy(item, memo))
            return cpy

        cpy = copy.copy(obj)
        memo[obj_id] = cpy

        if cpy is obj:
            # copy.copy handed back the very same object (atomic/immutable
            # value); rewriting its attributes would mutate the original.
            return cpy

        try:
            keys = list(obj.__dict__.keys())
        except AttributeError:
            try:
                keys = obj.__slots__
            except AttributeError:
                return cpy
            if isinstance(keys, str):
                # A single slot may be declared as a bare string.
                keys = (keys,)

        for key in keys:
            try:
                attr = getattr(obj, key)
            except AttributeError:
                # Declared slot that was never assigned: nothing to copy.
                continue
            setattr(cpy, key, relatively_deep_copy(attr, memo))

        return cpy

    memo = {}
    new_block = ir.Block(block.scope, block.loc)
    new_block.body = [relatively_deep_copy(stmt, memo) for stmt in block.body]
    return new_block
1080+
1081+
10061082
def lower_parfor_rollback(lowerer, parfor):
1007-
cache_parfor_races = copy.copy(parfor.races)
1008-
cache_parfor_params = copy.copy(parfor.params)
1009-
cache_parfor_loop_body = copy.deepcopy(parfor.loop_body)
1010-
cache_parfor_init_block = parfor.init_block.copy()
1011-
cache_parfor_loop_nests = parfor.loop_nests.copy()
1083+
try:
1084+
cache_parfor_races = copy.copy(parfor.races)
1085+
cache_parfor_params = copy.copy(parfor.params)
1086+
cache_parfor_loop_body = {key: copy_block(block) for key, block in parfor.loop_body.items()}
1087+
cache_parfor_init_block = parfor.init_block.copy()
1088+
cache_parfor_loop_nests = parfor.loop_nests.copy()
1089+
except Exception as e:
1090+
raise CopyIRException("Failed to copy IR") from e
10121091

10131092
try:
10141093
_lower_parfor_gufunc(lowerer, parfor)
10151094
if numba.dppl.compiler.DEBUG:
10161095
msg = "Parfor lowered on DPPL-device"
10171096
print(msg, parfor.loc)
10181097
except Exception as e:
1019-
msg = "Failed to lower parfor on DPPL-device"
1098+
msg = "Failed to lower parfor on DPPL-device.\nTo see details set environment variable NUMBA_DEBUG=1"
10201099
warnings.warn(NumbaPerformanceWarning(msg, parfor.loc))
10211100
raise e
10221101
finally:
Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,34 @@
1-
import numpy as np
2-
import numba
3-
from numba import dppl, njit, prange
4-
from numba.dppl.testing import unittest, DPPLTestCase
5-
from numba.tests.support import captured_stdout
6-
import dpctl
7-
8-
9-
def prange_example():
10-
n = 10
11-
a = np.ones((n, n), dtype=np.float64)
12-
b = np.ones((n, n), dtype=np.float64)
13-
c = np.ones((n, n), dtype=np.float64)
14-
for i in prange(n):
15-
a[i] = b[i] + c[i]
16-
17-
18-
@unittest.skipUnless(dpctl.has_gpu_queues(), 'test only on GPU system')
19-
class TestParforMessage(DPPLTestCase):
20-
def test_parfor_message(self):
21-
numba.dppl.compiler.DEBUG = 1
22-
jitted = njit(parallel={'offload':True})(prange_example)
23-
24-
with captured_stdout() as got:
25-
with dpctl.device_context("opencl:gpu") as gpu_queue:
26-
jitted()
27-
28-
self.assertTrue('Parfor lowered on DPPL-device' in got.getvalue())
29-
30-
31-
if __name__ == '__main__':
32-
unittest.main()
1+
import numpy as np
2+
import numba
3+
from numba import dppl, njit, prange
4+
from numba.dppl.testing import unittest, DPPLTestCase
5+
from numba.tests.support import captured_stdout
6+
import dpctl.ocldrv as ocldrv
7+
8+
9+
def prange_example():
    """Add two arrays of ones into the first half of a third, via prange."""
    n = 10
    out = np.ones(n, dtype=np.float64)
    lhs = np.ones(n, dtype=np.float64)
    rhs = np.ones(n, dtype=np.float64)
    for i in prange(n // 2):
        out[i] = lhs[i] + rhs[i]

    return out
18+
19+
20+
# NOTE(review): `has_gpu_device` is passed without being called; if it is a
# function rather than a boolean attribute the condition is always truthy
# and the test is never skipped -- confirm against dpctl.ocldrv.
@unittest.skipUnless(ocldrv.has_gpu_device, 'test only on GPU system')
class TestParforMessage(DPPLTestCase):
    """Checks the debug message printed when a parfor is lowered on device."""

    def test_parfor_message(self):
        """With DEBUG enabled, running the jitted prange prints the message."""
        # Save and restore the flag (even on failure) so this test does not
        # leak DEBUG state into, or clobber it for, other tests.
        old_debug = numba.dppl.compiler.DEBUG
        numba.dppl.compiler.DEBUG = 1
        try:
            jitted = njit(parallel={'offload':True})(prange_example)

            with captured_stdout() as got:
                jitted()
        finally:
            numba.dppl.compiler.DEBUG = old_debug

        self.assertIn('Parfor lowered on DPPL-device', got.getvalue())
31+
32+
33+
if __name__ == '__main__':
34+
unittest.main()

tests/dppl/test_prange.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
import sys
66
import numpy as np
7+
import numba
78
from numba import dppl, njit, prange
89
from numba.dppl.testing import unittest
910
from numba.dppl.testing import DPPLTestCase
11+
from numba.tests.support import captured_stdout
1012

1113

1214
class TestPrange(DPPLTestCase):
@@ -92,6 +94,60 @@ def f(a, b):
9294
self.assertTrue(np.all(b == 12))
9395

9496

97+
def test_two_consequent_prange(self):
    """Jitted 1-D prange example must lower on device and match plain Python.

    Expects the 'Parfor lowered on DPPL-device' message exactly twice and
    no 'Failed to lower parfor on DPPL-device' message in the debug output.
    """
    def prange_example():
        n = 10
        a = np.ones((n), dtype=np.float64)
        b = np.ones((n), dtype=np.float64)
        c = np.ones((n), dtype=np.float64)
        for i in prange(n//2):
            a[i] = b[i] + c[i]

        return a

    old_debug = numba.dppl.compiler.DEBUG
    numba.dppl.compiler.DEBUG = 1
    try:
        jitted = njit(parallel={'offload':True})(prange_example)
        with captured_stdout() as got:
            jitted_res = jitted()
    finally:
        # Restore the flag even if compilation/execution raises, so a
        # failure here does not switch debug output on for later tests.
        numba.dppl.compiler.DEBUG = old_debug

    res = prange_example()
    output = got.getvalue()

    self.assertEqual(output.count('Parfor lowered on DPPL-device'), 2)
    self.assertEqual(output.count('Failed to lower parfor on DPPL-device'), 0)
    np.testing.assert_equal(res, jitted_res)
122+
123+
124+
@unittest.skip('NRT required but not enabled')
def test_2d_arrays(self):
    """Jitted 2-D prange example must lower on device and match plain Python.

    Currently skipped. Expects the 'Parfor lowered on DPPL-device' message
    exactly twice and no 'Failed to lower parfor on DPPL-device' message.
    """
    def prange_example():
        n = 10
        a = np.ones((n, n), dtype=np.float64)
        b = np.ones((n, n), dtype=np.float64)
        c = np.ones((n, n), dtype=np.float64)
        for i in prange(n//2):
            a[i] = b[i] + c[i]

        return a

    old_debug = numba.dppl.compiler.DEBUG
    numba.dppl.compiler.DEBUG = 1
    try:
        jitted = njit(parallel={'offload':True})(prange_example)
        with captured_stdout() as got:
            jitted_res = jitted()
    finally:
        # Restore the flag even if compilation/execution raises, so a
        # failure here does not switch debug output on for later tests.
        numba.dppl.compiler.DEBUG = old_debug

    res = prange_example()
    output = got.getvalue()

    self.assertEqual(output.count('Parfor lowered on DPPL-device'), 2)
    self.assertEqual(output.count('Failed to lower parfor on DPPL-device'), 0)
    np.testing.assert_equal(res, jitted_res)
150+
95151

96152
if __name__ == '__main__':
97153
unittest.main()

0 commit comments

Comments
 (0)