Skip to content

Add callback mechanism so that array_finalize can see if obj is Numba meminfo. #214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 11, 2020
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions dpctl/dptensor/numpy_usm_shared.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
##===---------- dparray.py - dpctl -------*- Python -*----===##
##===---------- numpy_usm_shared.py - dpctl -------*- Python -*----===##
##
## Data Parallel Control (dpCtl)
##
Expand All @@ -19,7 +19,7 @@
##===----------------------------------------------------------------------===##
###
### \file
### This file implements a dparray - USM aware implementation of ndarray.
### This file implements a numpy_usm_shared - USM aware implementation of ndarray.
##===----------------------------------------------------------------------===##

import numpy as np
Expand Down Expand Up @@ -69,13 +69,17 @@ class ndarray(np.ndarray):
numpy.ndarray subclass whose underlying memory buffer is allocated
with a foreign allocator.
"""
external_usm_checkers = []

def add_external_usm_checker(func):
ndarray.external_usm_checkers.append(func)

def __new__(
subtype, shape, dtype=float, buffer=None, offset=0, strides=None, order=None
):
# Create a new array.
if buffer is None:
dprint("dparray::ndarray __new__ buffer None")
dprint("numpy_usm_shared::ndarray __new__ buffer None")
nelems = np.prod(shape)
dt = np.dtype(dtype)
isz = dt.itemsize
Expand All @@ -102,7 +106,7 @@ def __new__(
return new_obj
# zero copy if buffer is a usm backed array-like thing
elif hasattr(buffer, array_interface_property):
dprint("dparray::ndarray __new__ buffer", array_interface_property)
dprint("numpy_usm_shared::ndarray __new__ buffer", array_interface_property)
# also check for array interface
new_obj = np.ndarray.__new__(
subtype,
Expand All @@ -124,7 +128,7 @@ def __new__(
)
return new_obj
else:
dprint("dparray::ndarray __new__ buffer not None and not sycl_usm")
dprint("numpy_usm_shared::ndarray __new__ buffer not None and not sycl_usm")
nelems = np.prod(shape)
# must copy
ar = np.ndarray(
Expand Down Expand Up @@ -158,6 +162,9 @@ def __new__(
)
return new_obj

def __sycl_usm_array_interface__(self):
return self._getter_sycl_usm_array_interface()

def _getter_sycl_usm_array_interface_(self):
ary_iface = self.__array_interface__
_base = _get_usm_base(self)
Expand Down Expand Up @@ -186,6 +193,9 @@ def __array_finalize__(self, obj):
# subclass of ndarray, including our own.
if hasattr(obj, array_interface_property):
return
for ext_checker in ndarray.external_usm_checkers:
if ext_checker(obj):
return
if isinstance(obj, np.ndarray):
ob = self
while isinstance(ob, np.ndarray):
Expand All @@ -200,7 +210,7 @@ def __array_finalize__(self, obj):
)

# Tell Numba to not treat this type just like a NumPy ndarray but to propagate its type.
# This way it will use the custom dparray allocator.
# This way it will use the custom numpy_usm_shared allocator.
__numba_no_subtype_ndarray__ = True

# Convert to a NumPy ndarray.
Expand Down Expand Up @@ -234,8 +244,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
else:
return NotImplemented
# Have to avoid recursive calls to array_ufunc here.
# If no out kwarg then we create a dparray out so that we get
# USM memory. However, if kwarg has dparray-typed out then
# If no out kwarg then we create a numpy_usm_shared out so that we get
# USM memory. However, if kwarg has numpy_usm_shared-typed out then
# array_ufunc is called recursively so we cast out as regular
# NumPy ndarray (having a USM data pointer).
if kwargs.get("out", None) is None:
Expand All @@ -246,7 +256,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
out_as_np = np.ndarray(out.shape, out.dtype, out)
kwargs["out"] = out_as_np
else:
# If they manually gave dparray as out kwarg then we have to also
# If they manually gave numpy_usm_shared as out kwarg then we have to also
# cast as regular NumPy ndarray to avoid recursion.
if isinstance(kwargs["out"], ndarray):
out = kwargs["out"]
Expand All @@ -271,7 +281,7 @@ def isdef(x):
cname = c[0]
if isdef(cname):
continue
# For now we do the simple thing and copy the types from NumPy module into dparray module.
# For now we do the simple thing and copy the types from NumPy module into numpy_usm_shared module.
new_func = "%s = np.%s" % (cname, cname)
try:
the_code = compile(new_func, "__init__", "exec")
Expand Down