Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit 7ccc8e3

Browse files
committed
Added check to not error out if OpenCL is absent
1 parent 55032ef commit 7ccc8e3

File tree

5 files changed

+41
-17
lines changed

5 files changed

+41
-17
lines changed

numba/core/cpu_dispatcher.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from numba.core import dispatcher, compiler
22
from numba.core.registry import cpu_target, dispatcher_registry
3-
from numba.dppy.compiler import DPPyCompiler
43

54

65
class CPUDispatcher(dispatcher.Dispatcher):
@@ -9,8 +8,17 @@ class CPUDispatcher(dispatcher.Dispatcher):
98
def __init__(self, py_func, locals={}, targetoptions={}, impl_kind='direct', pipeline_class=compiler.Compiler):
109
if ('parallel' in targetoptions and isinstance(targetoptions['parallel'], dict) and
1110
'spirv' in targetoptions['parallel'] and targetoptions['parallel']['spirv'] == True):
12-
dispatcher.Dispatcher.__init__(self, py_func, locals=locals,
13-
targetoptions=targetoptions, impl_kind=impl_kind, pipeline_class=DPPyCompiler)
11+
import numba.dppy_config as dppy_config
12+
if dppy_config.dppy_present:
13+
from numba.dppy.compiler import DPPyCompiler
14+
dispatcher.Dispatcher.__init__(self, py_func, locals=locals,
15+
targetoptions=targetoptions, impl_kind=impl_kind, pipeline_class=DPPyCompiler)
16+
else:
17+
print("-------------------------------------------------------------------------------------")
18+
print("WARNING : offload = True was set but since dppy is not present, we are not offloading")
19+
print("-------------------------------------------------------------------------------------")
20+
dispatcher.Dispatcher.__init__(self, py_func, locals=locals,
21+
targetoptions=targetoptions, impl_kind=impl_kind, pipeline_class=pipeline_class)
1422
else:
1523
dispatcher.Dispatcher.__init__(self, py_func, locals=locals,
1624
targetoptions=targetoptions, impl_kind=impl_kind, pipeline_class=pipeline_class)

numba/core/typing/context.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414

1515
from numba.core import utils
1616

17-
import dppy.ocldrv as driver
18-
1917

2018
class Rating(object):
2119
__slots__ = 'promote', 'safe_convert', "unsafe_convert"
@@ -357,13 +355,17 @@ def resolve_argument_type(self, val):
357355
try:
358356
return typeof(val, Purpose.argument)
359357
except ValueError:
360-
if(type(val) == driver.DeviceArray):
361-
return typeof(val._ndarray, Purpose.argument)
362-
# DRD : Hmmm... is the assumption that this error is encountered
363-
# when someone is using cuda, and already has done an import
364-
# cuda?
365-
#elif numba.cuda.is_cuda_array(val):
366-
# return typeof(numba.cuda.as_cuda_array(val), Purpose.argument)
358+
from numba.dppy_config import dppy_present, DeviceArray
359+
if dppy_present:
360+
if(type(val) == DeviceArray):
361+
return typeof(val._ndarray, Purpose.argument)
362+
# DRD : Hmmm... is the assumption that this error is encountered
363+
# when someone is using cuda, and already has done an import
364+
# cuda?
365+
#elif numba.cuda.is_cuda_array(val):
366+
# return typeof(numba.cuda.as_cuda_array(val), Purpose.argument)
367+
else:
368+
raise
367369
else:
368370
raise
369371

numba/dppy/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@
55

66
from numba import config
77
import numba.testing
8-
from dppy.ocldrv import *
9-
from .device_init import *
8+
9+
from numba.dppy_config import *
10+
if dppy_present:
11+
from .device_init import *
12+
else:
13+
raise ImportError("Importing dppy failed")
1014

1115
def test(*args, **kwargs):
12-
if not is_available():
16+
if dppy_present and not is_available():
1317
dppy_error()
1418

1519
return numba.testing.test("numba.dppy.tests", *args, **kwargs)

numba/dppy/tests/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from numba.testing import SerialSuite
22
from numba.testing import load_testsuite
33
from os.path import dirname, join
4-
import dppy.ocldrv as ocldrv
4+
5+
6+
import numba.dppy_config as dppy_config
57

68
def load_tests(loader, tests, pattern):
79

810
suite = SerialSuite()
911
this_dir = dirname(__file__)
1012

11-
if ocldrv.is_available():
13+
if dppy_config.dppy_present and dppy_config.is_available():
1214
suite.addTests(load_testsuite(loader, join(this_dir, 'dppy')))
1315
else:
1416
print("skipped DPPY tests")

numba/dppy_config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
dppy_present = False
2+
3+
try:
4+
from dppy.ocldrv import *
5+
except:
6+
pass
7+
else:
8+
dppy_present = True

0 commit comments

Comments
 (0)