@@ -97,7 +97,30 @@ def init(self):
97
97
.SPIR_DATA_LAYOUT [utils .MACHINE_BITS ]))
98
98
# Override data model manager to SPIR model
99
99
self .data_model_manager = spirv_data_model_manager
100
- self .done_once = False
100
+
101
+ from numba .np .ufunc_db import _ufunc_db as ufunc_db , _lazy_init_db
102
+ import copy
103
+ _lazy_init_db ()
104
+ self .ufunc_db = copy .deepcopy (ufunc_db )
105
+
106
+
107
+ def replace_numpy_ufunc_with_opencl_supported_functions (self ):
108
+ from numba .dppy .ocl .mathimpl import lower_ocl_impl , sig_mapper
109
+
110
+ ufuncs = [("fabs" , np .fabs ), ("exp" , np .exp ), ("log" , np .log ),
111
+ ("log10" , np .log10 ), ("expm1" , np .expm1 ), ("log1p" , np .log1p ),
112
+ ("sqrt" , np .sqrt ), ("sin" , np .sin ), ("cos" , np .cos ),
113
+ ("tan" , np .tan ), ("asin" , np .arcsin ), ("acos" , np .arccos ),
114
+ ("atan" , np .arctan ), ("atan2" , np .arctan2 ), ("sinh" , np .sinh ),
115
+ ("cosh" , np .cosh ), ("tanh" , np .tanh ), ("asinh" , np .arcsinh ),
116
+ ("acosh" , np .arccosh ), ("atanh" , np .arctanh ), ("ldexp" , np .ldexp ),
117
+ ("floor" , np .floor ), ("ceil" , np .ceil ), ("trunc" , np .trunc )]
118
+
119
+ for name , ufunc in ufuncs :
120
+ for sig in self .ufunc_db [ufunc ].keys ():
121
+ if sig in sig_mapper and (name , sig_mapper [sig ]) in lower_ocl_impl :
122
+ self .ufunc_db [ufunc ][sig ] = lower_ocl_impl [(name , sig_mapper [sig ])]
123
+
101
124
102
125
def load_additional_registries (self ):
103
126
from .ocl import oclimpl , mathimpl
@@ -111,9 +134,7 @@ def load_additional_registries(self):
111
134
functions we will redirect some of NUMBA's NumPy
112
135
ufunc with OpenCL's.
113
136
"""
114
- if not self .done_once :
115
- _replace_numpy_ufunc_with_opencl_supported_functions ()
116
- self .done_once = True
137
+ self .replace_numpy_ufunc_with_opencl_supported_functions ()
117
138
118
139
119
140
@cached_property
0 commit comments