Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit a77eab6

Browse files
committed
DPPy target will retain a copy of ufunc_db for replacing the implementation of numpy ufuncs with that of OpenCL's when such function exists
1 parent 0669d98 commit a77eab6

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

numba/dppy/target.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,30 @@ def init(self):
9797
.SPIR_DATA_LAYOUT[utils.MACHINE_BITS]))
9898
# Override data model manager to SPIR model
9999
self.data_model_manager = spirv_data_model_manager
100-
self.done_once = False
100+
101+
from numba.np.ufunc_db import _ufunc_db as ufunc_db, _lazy_init_db
102+
import copy
103+
_lazy_init_db()
104+
self.ufunc_db = copy.deepcopy(ufunc_db)
105+
106+
107+
def replace_numpy_ufunc_with_opencl_supported_functions(self):
108+
from numba.dppy.ocl.mathimpl import lower_ocl_impl, sig_mapper
109+
110+
ufuncs = [("fabs", np.fabs), ("exp", np.exp), ("log", np.log),
111+
("log10", np.log10), ("expm1", np.expm1), ("log1p", np.log1p),
112+
("sqrt", np.sqrt), ("sin", np.sin), ("cos", np.cos),
113+
("tan", np.tan), ("asin", np.arcsin), ("acos", np.arccos),
114+
("atan", np.arctan), ("atan2", np.arctan2), ("sinh", np.sinh),
115+
("cosh", np.cosh), ("tanh", np.tanh), ("asinh", np.arcsinh),
116+
("acosh", np.arccosh), ("atanh", np.arctanh), ("ldexp", np.ldexp),
117+
("floor", np.floor), ("ceil", np.ceil), ("trunc", np.trunc)]
118+
119+
for name, ufunc in ufuncs:
120+
for sig in self.ufunc_db[ufunc].keys():
121+
if sig in sig_mapper and (name, sig_mapper[sig]) in lower_ocl_impl:
122+
self.ufunc_db[ufunc][sig] = lower_ocl_impl[(name, sig_mapper[sig])]
123+
101124

102125
def load_additional_registries(self):
103126
from .ocl import oclimpl, mathimpl
@@ -111,9 +134,7 @@ def load_additional_registries(self):
111134
functions we will redirect some of NUMBA's NumPy
112135
ufunc with OpenCL's.
113136
"""
114-
if not self.done_once:
115-
_replace_numpy_ufunc_with_opencl_supported_functions()
116-
self.done_once = True
137+
self.replace_numpy_ufunc_with_opencl_supported_functions()
117138

118139

119140
@cached_property

numba/np/npyimpl.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,12 @@ def __init__(self, context, builder, outer_sig):
412412
super(_KernelImpl, self).__init__(context, builder, outer_sig)
413413
loop = ufunc_find_matching_loop(
414414
ufunc, outer_sig.args + (outer_sig.return_type,))
415-
self.fn = ufunc_db.get_ufunc_info(ufunc).get(loop.ufunc_sig)
415+
416+
if hasattr(context, 'ufunc_db'):
417+
self.fn = context.ufunc_db[ufunc].get(loop.ufunc_sig)
418+
else:
419+
self.fn = ufunc_db.get_ufunc_info(ufunc).get(loop.ufunc_sig)
420+
416421
self.inner_sig = typing.signature(
417422
*(loop.outputs + loop.inputs))
418423

0 commit comments

Comments
 (0)