Closed
Description
Goal
Allow users to call the following NumPy functions (initially) inside functions decorated with @njit and have them offloaded to the GPU.
- np.sum (up to 3D tested)
- np.prod (up to 3D tested)
- np.argmax (up to 3D tested)
- np.max (up to 3D tested)
- np.argmin (up to 3D tested)
- np.min (up to 3D tested)
- np.argsort
- np.median (up to 3D tested)
- np.mean (up to 3D tested)
- np.matmul
- np.dot
- np.cov
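For reference, a quick host-side check of the plain-NumPy semantics that the dpnp-backed, offloaded versions would have to reproduce (no numba or dpnp required for this sketch):

```python
import numpy as np

# Reference results for a few of the listed functions; an offloaded
# implementation must match these bit-for-bit (or within float tolerance).
a = np.arange(12.0).reshape(3, 4)  # [[0..3], [4..7], [8..11]]

total = np.sum(a)         # sum over all elements
largest = np.max(a)       # global maximum
where = np.argmax(a)      # flat index of the maximum
avg = np.mean(a)          # arithmetic mean
gram = np.matmul(a, a.T)  # matrix product, shape (3, 3)
```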
Issues
- Numba already has overloads for these functions, and the existing overloads cannot be swapped in and out.
- Numba uses a mix of old and new techniques to implement these functions: 1. @infer_type and @lower_builtin, 2. @overload. Ideally we want to reuse most of Numba's existing implementation and simply make a call to dpnp.
- Using @overload does not work because we cannot specify which overload gets used.
- Using the first technique requires us to know the NumPy function very well in order to work out the typing for every combination of inputs, which is non-trivial and time-consuming to get right. Moreover, the mechanism for overriding existing typing and lowering is hackish.
Possible solution
- Implement a pass that rewrites the names of the selected NumPy functions, so we avoid having to deal with the existing overloads.
- Port the current implementation of calling into dpnp from technique 1 (@infer_type and @lower_builtin) to technique 2 (@overload).
Current implementation:
https://github.com/IntelPython/numba-dppy/blob/main/numba_dppy/experimental_numpy_lowering_overload.py
POC of possible solution: