-
Notifications
You must be signed in to change notification settings - Fork 32
Pass to rewrite Numpy function names to be able to overload them for Numba-dppy pipeline #52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
diptorupd
merged 16 commits into
IntelPython:main
from
reazulhoque:feature/rewrite_pass_to_rename_functions
Dec 9, 2020
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
f5ea456
Sum example
reazulhoque 1d346cd
Moved from infer_type, lower_builtin to overload
reazulhoque 3dd9f24
Added two level module name functions
reazulhoque 0003586
Remove cython generated file
reazulhoque 1f85af9
Module name fix for moving to new extension
reazulhoque a41b9a2
Incomplete linalg.eig implementation
reazulhoque 0c9dfcb
Merge branch 'main' into feature/rewrite_pass_to_rename_functions
reazulhoque 4cd5c82
Updated dppl to dppy
reazulhoque 55ff896
Updted all dppl to dppy and moved rewrite_numpy_function_pass to it's…
reazulhoque 687a52a
Import module at correct locations
reazulhoque 0fc597a
Added comments
reazulhoque 35c22ef
Added test and updated comments
reazulhoque ab646b6
Revert unneeded changes
reazulhoque 10b90f1
Update Eigen implementation
reazulhoque e65b6cd
Remove eig implementation
reazulhoque a3524ca
Add checking equivalent IR
reazulhoque File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from numba import types | ||
from numba.core.typing import signature | ||
|
||
|
||
class _DPCTL_FUNCTIONS: | ||
@classmethod | ||
def dpctl_get_current_queue(cls): | ||
ret_type = types.voidptr | ||
sig = signature(ret_type) | ||
return types.ExternalFunction("DPCTLQueueMgr_GetCurrentQueue", sig) | ||
|
||
@classmethod | ||
def dpctl_malloc_shared(cls): | ||
ret_type = types.voidptr | ||
sig = signature(ret_type, types.int64, types.voidptr) | ||
return types.ExternalFunction("DPCTLmalloc_shared", sig) | ||
|
||
@classmethod | ||
def dpctl_queue_memcpy(cls): | ||
ret_type = types.void | ||
sig = signature( | ||
ret_type, types.voidptr, types.voidptr, types.voidptr, types.int64 | ||
) | ||
return types.ExternalFunction("DPCTLQueue_Memcpy", sig) | ||
|
||
@classmethod | ||
def dpctl_free_with_queue(cls): | ||
ret_type = types.void | ||
sig = signature(ret_type, types.voidptr, types.voidptr) | ||
return types.ExternalFunction("DPCTLfree_with_queue", sig) |
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from numba.core.typing.templates import (AttributeTemplate, infer_getattr) | ||
import numba_dppy | ||
from numba import types | ||
|
||
@infer_getattr | ||
class DppyDpnpTemplate(AttributeTemplate): | ||
key = types.Module(numba_dppy) | ||
|
||
def resolve_dpnp(self, mod): | ||
return types.Module(numba_dppy.dpnp) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
from numba.core.imputils import lower_builtin | ||
import numba_dppy.experimental_numpy_lowering_overload as dpnp_lowering | ||
from numba import types | ||
from numba.core.typing import signature | ||
from numba.core.extending import overload, register_jitable | ||
from . import stubs | ||
import numpy as np | ||
from numba_dppy.dpctl_functions import _DPCTL_FUNCTIONS | ||
|
||
|
||
def get_dpnp_fptr(fn_name, type_names): | ||
from . import dpnp_fptr_interface as dpnp_glue | ||
|
||
f_ptr = dpnp_glue.get_dpnp_fn_ptr(fn_name, type_names) | ||
return f_ptr | ||
|
||
|
||
@register_jitable | ||
def _check_finite_matrix(a): | ||
for v in np.nditer(a): | ||
if not np.isfinite(v.item()): | ||
raise np.linalg.LinAlgError("Array must not contain infs or NaNs.") | ||
|
||
|
||
@register_jitable | ||
def _dummy_liveness_func(a): | ||
"""pass a list of variables to be preserved through dead code elimination""" | ||
return a[0] | ||
|
||
|
||
class RetrieveDpnpFnPtr(types.ExternalFunctionPointer): | ||
def __init__(self, fn_name, type_names, sig, get_pointer): | ||
self.fn_name = fn_name | ||
self.type_names = type_names | ||
super(RetrieveDpnpFnPtr, self).__init__(sig, get_pointer) | ||
|
||
|
||
class _DPNP_EXTENSION: | ||
def __init__(self, name): | ||
dpnp_lowering.ensure_dpnp(name) | ||
|
||
@classmethod | ||
def dpnp_sum(cls, fn_name, type_names): | ||
ret_type = types.void | ||
sig = signature(ret_type, types.voidptr, types.voidptr, types.int64) | ||
f_ptr = get_dpnp_fptr(fn_name, type_names) | ||
|
||
def get_pointer(obj): | ||
return f_ptr | ||
|
||
return types.ExternalFunctionPointer(sig, get_pointer=get_pointer) | ||
|
||
|
||
@overload(stubs.dpnp.sum) | ||
def dpnp_sum_impl(a): | ||
dpnp_extension = _DPNP_EXTENSION("sum") | ||
dpctl_functions = _DPCTL_FUNCTIONS() | ||
|
||
dpnp_sum = dpnp_extension.dpnp_sum("dpnp_sum", [a.dtype.name, "NONE"]) | ||
|
||
get_sycl_queue = dpctl_functions.dpctl_get_current_queue() | ||
allocate_usm_shared = dpctl_functions.dpctl_malloc_shared() | ||
copy_usm = dpctl_functions.dpctl_queue_memcpy() | ||
free_usm = dpctl_functions.dpctl_free_with_queue() | ||
|
||
def dpnp_sum_impl(a): | ||
if a.size == 0: | ||
raise ValueError("Passed Empty array") | ||
|
||
sycl_queue = get_sycl_queue() | ||
a_usm = allocate_usm_shared(a.size * a.itemsize, sycl_queue) | ||
copy_usm(sycl_queue, a_usm, a.ctypes, a.size * a.itemsize) | ||
|
||
out_usm = allocate_usm_shared(a.itemsize, sycl_queue) | ||
|
||
dpnp_sum(a_usm, out_usm, a.size) | ||
|
||
out = np.empty(1, dtype=a.dtype) | ||
copy_usm(sycl_queue, out.ctypes, out_usm, out.size * out.itemsize) | ||
|
||
free_usm(a_usm, sycl_queue) | ||
free_usm(out_usm, sycl_queue) | ||
|
||
|
||
_dummy_liveness_func([out.size]) | ||
|
||
return out[0] | ||
|
||
return dpnp_sum_impl |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from numba_dppy.ocl.stubs import Stub | ||
|
||
class dpnp(Stub): | ||
"""dpnp namespace | ||
""" | ||
_description_ = '<dpnp>' | ||
|
||
class sum(Stub): | ||
pass |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.