Skip to content

Commit 023fef9

Browse files
reazulhoqueDrTodd13
authored andcommitted
Pass to rewrite Numpy function names to be able to overload them for Numba-dppy pipeline (#52)
* Sum example * Moved from infer_type, lower_builtin to overload * Added two level module name functions * Remove cython generated file * Module name fix for moving to new extension * Incomplete linalg.eig implementation * Updted all dppl to dppy and moved rewrite_numpy_function_pass to it's own file * Import module at correct locations * Added comments * Added test and updated comments * Revert unneeded changes * Update Eigen implementation * Remove eig implementation * Add checking equivalent IR Co-authored-by: reazul.hoque <reazul.hoque@intel.com>
1 parent 0a56e08 commit 023fef9

11 files changed

+389
-8
lines changed

numba_dppy/device_init.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@
1818
CLK_GLOBAL_MEM_FENCE,
1919
)
2020

21+
"""
22+
We are importing dpnp stub module to make Numba recognize the
23+
module when we rename Numpy functions.
24+
"""
25+
from .dpnp_glue.stubs import (
26+
dpnp
27+
)
28+
2129
DEFAULT_LOCAL_SIZE = []
2230

2331
from . import initialize
@@ -35,9 +43,4 @@ def is_available():
3543
return dpctl.has_gpu_queues()
3644

3745

38-
#def ocl_error():
39-
# """Returns None or an exception if the OpenCL driver fails to initialize.
40-
# """
41-
# return driver.driver.initialization_error
42-
4346
initialize.initialize_all()

numba_dppy/dpctl_functions.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from numba import types
2+
from numba.core.typing import signature
3+
4+
5+
class _DPCTL_FUNCTIONS:
6+
@classmethod
7+
def dpctl_get_current_queue(cls):
8+
ret_type = types.voidptr
9+
sig = signature(ret_type)
10+
return types.ExternalFunction("DPCTLQueueMgr_GetCurrentQueue", sig)
11+
12+
@classmethod
13+
def dpctl_malloc_shared(cls):
14+
ret_type = types.voidptr
15+
sig = signature(ret_type, types.int64, types.voidptr)
16+
return types.ExternalFunction("DPCTLmalloc_shared", sig)
17+
18+
@classmethod
19+
def dpctl_queue_memcpy(cls):
20+
ret_type = types.void
21+
sig = signature(
22+
ret_type, types.voidptr, types.voidptr, types.voidptr, types.int64
23+
)
24+
return types.ExternalFunction("DPCTLQueue_Memcpy", sig)
25+
26+
@classmethod
27+
def dpctl_free_with_queue(cls):
28+
ret_type = types.void
29+
sig = signature(ret_type, types.voidptr, types.voidptr)
30+
return types.ExternalFunction("DPCTLfree_with_queue", sig)

numba_dppy/dpnp_glue/__init__.py

Whitespace-only changes.

numba_dppy/dpnp_glue/dpnp_fptr_interface.pyx

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ cdef extern from "backend_iface_fptr.hpp" namespace "DPNPFuncName": # need this
88
cdef enum DPNPFuncName "DPNPFuncName":
99
DPNP_FN_ABSOLUTE
1010
DPNP_FN_ADD
11+
DPNP_FN_ARANGE
1112
DPNP_FN_ARCCOS
1213
DPNP_FN_ARCCOSH
1314
DPNP_FN_ARCSIN
@@ -18,40 +19,77 @@ cdef extern from "backend_iface_fptr.hpp" namespace "DPNPFuncName": # need this
1819
DPNP_FN_ARGMAX
1920
DPNP_FN_ARGMIN
2021
DPNP_FN_ARGSORT
22+
DPNP_FN_BITWISE_AND
23+
DPNP_FN_BITWISE_OR
24+
DPNP_FN_BITWISE_XOR
2125
DPNP_FN_CBRT
2226
DPNP_FN_CEIL
27+
DPNP_FN_CHOLESKY
28+
DPNP_FN_COPYSIGN
29+
DPNP_FN_CORRELATE
2330
DPNP_FN_COS
2431
DPNP_FN_COSH
2532
DPNP_FN_COV
2633
DPNP_FN_DEGREES
34+
DPNP_FN_DET
2735
DPNP_FN_DIVIDE
2836
DPNP_FN_DOT
2937
DPNP_FN_EIG
38+
DPNP_FN_EIGVALS
3039
DPNP_FN_EXP
3140
DPNP_FN_EXP2
3241
DPNP_FN_EXPM1
3342
DPNP_FN_FABS
43+
DPNP_FN_FFT_FFT
3444
DPNP_FN_FLOOR
45+
DPNP_FN_FLOOR_DIVIDE
3546
DPNP_FN_FMOD
36-
DPNP_FN_GAUSSIAN
3747
DPNP_FN_HYPOT
48+
DPNP_FN_INVERT
49+
DPNP_FN_LEFT_SHIFT
3850
DPNP_FN_LOG
3951
DPNP_FN_LOG10
4052
DPNP_FN_LOG1P
4153
DPNP_FN_LOG2
4254
DPNP_FN_MATMUL
55+
DPNP_FN_MATRIX_RANK
4356
DPNP_FN_MAX
4457
DPNP_FN_MAXIMUM
4558
DPNP_FN_MEAN
4659
DPNP_FN_MEDIAN
4760
DPNP_FN_MIN
4861
DPNP_FN_MINIMUM
62+
DPNP_FN_MODF
4963
DPNP_FN_MULTIPLY
5064
DPNP_FN_POWER
5165
DPNP_FN_PROD
52-
DPNP_FN_UNIFORM
5366
DPNP_FN_RADIANS
67+
DPNP_FN_REMAINDER
5468
DPNP_FN_RECIP
69+
DPNP_FN_RIGHT_SHIFT
70+
DPNP_FN_RNG_BETA
71+
DPNP_FN_RNG_BINOMIAL
72+
DPNP_FN_RNG_CHISQUARE
73+
DPNP_FN_RNG_EXPONENTIAL
74+
DPNP_FN_RNG_GAMMA
75+
DPNP_FN_RNG_GAUSSIAN
76+
DPNP_FN_RNG_GEOMETRIC
77+
DPNP_FN_RNG_GUMBEL
78+
DPNP_FN_RNG_HYPERGEOMETRIC
79+
DPNP_FN_RNG_LAPLACE
80+
DPNP_FN_RNG_LOGNORMAL
81+
DPNP_FN_RNG_MULTINOMIAL
82+
DPNP_FN_RNG_MULTIVARIATE_NORMAL
83+
DPNP_FN_RNG_NEGATIVE_BINOMIAL
84+
DPNP_FN_RNG_NORMAL
85+
DPNP_FN_RNG_POISSON
86+
DPNP_FN_RNG_RAYLEIGH
87+
DPNP_FN_RNG_STANDARD_CAUCHY
88+
DPNP_FN_RNG_STANDARD_EXPONENTIAL
89+
DPNP_FN_RNG_STANDARD_GAMMA
90+
DPNP_FN_RNG_STANDARD_NORMAL
91+
DPNP_FN_RNG_UNIFORM
92+
DPNP_FN_RNG_WEIBULL
5593
DPNP_FN_SIGN
5694
DPNP_FN_SIN
5795
DPNP_FN_SINH
@@ -109,6 +147,8 @@ cdef DPNPFuncName get_DPNPFuncName_from_str(name):
109147
return DPNPFuncName.DPNP_FN_ARGSORT
110148
elif name == "dpnp_cov":
111149
return DPNPFuncName.DPNP_FN_COV
150+
elif name == "dpnp_eig":
151+
return DPNPFuncName.DPNP_FN_EIG
112152
else:
113153
return DPNPFuncName.DPNP_FN_DOT
114154

numba_dppy/dpnp_glue/dpnpdecl.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from numba.core.typing.templates import (AttributeTemplate, infer_getattr)
2+
import numba_dppy
3+
from numba import types
4+
5+
@infer_getattr
6+
class DppyDpnpTemplate(AttributeTemplate):
7+
key = types.Module(numba_dppy)
8+
9+
def resolve_dpnp(self, mod):
10+
return types.Module(numba_dppy.dpnp)

numba_dppy/dpnp_glue/dpnpimpl.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from numba.core.imputils import lower_builtin
2+
import numba_dppy.experimental_numpy_lowering_overload as dpnp_lowering
3+
from numba import types
4+
from numba.core.typing import signature
5+
from numba.core.extending import overload, register_jitable
6+
from . import stubs
7+
import numpy as np
8+
from numba_dppy.dpctl_functions import _DPCTL_FUNCTIONS
9+
10+
11+
def get_dpnp_fptr(fn_name, type_names):
12+
from . import dpnp_fptr_interface as dpnp_glue
13+
14+
f_ptr = dpnp_glue.get_dpnp_fn_ptr(fn_name, type_names)
15+
return f_ptr
16+
17+
18+
@register_jitable
19+
def _check_finite_matrix(a):
20+
for v in np.nditer(a):
21+
if not np.isfinite(v.item()):
22+
raise np.linalg.LinAlgError("Array must not contain infs or NaNs.")
23+
24+
25+
@register_jitable
26+
def _dummy_liveness_func(a):
27+
"""pass a list of variables to be preserved through dead code elimination"""
28+
return a[0]
29+
30+
31+
class RetrieveDpnpFnPtr(types.ExternalFunctionPointer):
32+
def __init__(self, fn_name, type_names, sig, get_pointer):
33+
self.fn_name = fn_name
34+
self.type_names = type_names
35+
super(RetrieveDpnpFnPtr, self).__init__(sig, get_pointer)
36+
37+
38+
class _DPNP_EXTENSION:
39+
def __init__(self, name):
40+
dpnp_lowering.ensure_dpnp(name)
41+
42+
@classmethod
43+
def dpnp_sum(cls, fn_name, type_names):
44+
ret_type = types.void
45+
sig = signature(ret_type, types.voidptr, types.voidptr, types.int64)
46+
f_ptr = get_dpnp_fptr(fn_name, type_names)
47+
48+
def get_pointer(obj):
49+
return f_ptr
50+
51+
return types.ExternalFunctionPointer(sig, get_pointer=get_pointer)
52+
53+
54+
@overload(stubs.dpnp.sum)
55+
def dpnp_sum_impl(a):
56+
dpnp_extension = _DPNP_EXTENSION("sum")
57+
dpctl_functions = _DPCTL_FUNCTIONS()
58+
59+
dpnp_sum = dpnp_extension.dpnp_sum("dpnp_sum", [a.dtype.name, "NONE"])
60+
61+
get_sycl_queue = dpctl_functions.dpctl_get_current_queue()
62+
allocate_usm_shared = dpctl_functions.dpctl_malloc_shared()
63+
copy_usm = dpctl_functions.dpctl_queue_memcpy()
64+
free_usm = dpctl_functions.dpctl_free_with_queue()
65+
66+
def dpnp_sum_impl(a):
67+
if a.size == 0:
68+
raise ValueError("Passed Empty array")
69+
70+
sycl_queue = get_sycl_queue()
71+
a_usm = allocate_usm_shared(a.size * a.itemsize, sycl_queue)
72+
copy_usm(sycl_queue, a_usm, a.ctypes, a.size * a.itemsize)
73+
74+
out_usm = allocate_usm_shared(a.itemsize, sycl_queue)
75+
76+
dpnp_sum(a_usm, out_usm, a.size)
77+
78+
out = np.empty(1, dtype=a.dtype)
79+
copy_usm(sycl_queue, out.ctypes, out_usm, out.size * out.itemsize)
80+
81+
free_usm(a_usm, sycl_queue)
82+
free_usm(out_usm, sycl_queue)
83+
84+
85+
_dummy_liveness_func([out.size])
86+
87+
return out[0]
88+
89+
return dpnp_sum_impl

numba_dppy/dpnp_glue/stubs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from numba_dppy.ocl.stubs import Stub
2+
3+
class dpnp(Stub):
4+
"""dpnp namespace
5+
"""
6+
_description_ = '<dpnp>'
7+
8+
class sum(Stub):
9+
pass

numba_dppy/dppy_passbuilder.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
DPPYNoPythonBackend
2828
)
2929

30+
from .rename_numpy_functions_pass import DPPYRewriteOverloadedFunctions
31+
3032
class DPPYPassBuilder(object):
3133
"""
3234
This is the DPPY pass builder to run Intel GPU/CPU specific
@@ -44,6 +46,11 @@ def default_numba_nopython_pipeline(state, pm):
4446
pm.add_pass(IRProcessing, "processing IR")
4547
pm.add_pass(WithLifting, "Handle with contexts")
4648

49+
# this pass rewrites name of NumPy functions we intend to overload
50+
pm.add_pass(DPPYRewriteOverloadedFunctions,
51+
"Rewrite name of Numpy functions to overload already overloaded function",
52+
)
53+
4754
# this pass adds required logic to overload default implementation of
4855
# Numpy functions
4956
pm.add_pass(DPPYAddNumpyOverloadPass, "dppy add typing template for Numpy functions")

numba_dppy/dppy_passes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import warnings
44

55
import numpy as np
6+
import numba
67
from numba.core import ir
78
import weakref
89
from collections import namedtuple, deque
@@ -49,7 +50,7 @@ def __init__(self):
4950
def run_pass(self, state):
5051
if dpnp_available():
5152
typingctx = state.typingctx
52-
from numba.core.typing.templates import builtin_registry as reg, infer_global
53+
from numba.core.typing.templates import (builtin_registry as reg, infer_global)
5354
from numba.core.typing.templates import (AbstractTemplate, CallableTemplate, signature)
5455
from numba.core.typing.npydecl import MatMulTyperMixin
5556

0 commit comments

Comments
 (0)