Skip to content

Commit 1d38c5c

Browse files
Add callback mechanism so that array_finalize can see if obj is Numba meminfo. (#214)
* change dparray to numpy_usm_shared. add callback mechanism so that numba_dppy.numpy_usm_shared can register with dpctl.dptensor.numpy_usm_shared.ndarray a callback function to look and see if the object is a Numba MemInfo with USM allocator. * Define sycl usm interface statically. Co-authored-by: Sergey Pokhodenko <sergey.pokhodenko@intel.com>
1 parent 8db2b2c commit 1d38c5c

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

dpctl/dptensor/numpy_usm_shared.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
##===---------- dparray.py - dpctl -------*- Python -*----===##
1+
##===---------- numpy_usm_shared.py - dpctl -------*- Python -*----===##
22
##
33
## Data Parallel Control (dpCtl)
44
##
@@ -19,7 +19,7 @@
1919
##===----------------------------------------------------------------------===##
2020
###
2121
### \file
22-
### This file implements a dparray - USM aware implementation of ndarray.
22+
### This file implements a numpy_usm_shared - USM aware implementation of ndarray.
2323
##===----------------------------------------------------------------------===##
2424

2525
import numpy as np
@@ -70,12 +70,17 @@ class ndarray(np.ndarray):
7070
with a foreign allocator.
7171
"""
7272

73+
external_usm_checkers = []
74+
75+
def add_external_usm_checker(func):
76+
ndarray.external_usm_checkers.append(func)
77+
7378
def __new__(
7479
subtype, shape, dtype=float, buffer=None, offset=0, strides=None, order=None
7580
):
7681
# Create a new array.
7782
if buffer is None:
78-
dprint("dparray::ndarray __new__ buffer None")
83+
dprint("numpy_usm_shared::ndarray __new__ buffer None")
7984
nelems = np.prod(shape)
8085
dt = np.dtype(dtype)
8186
isz = dt.itemsize
@@ -102,7 +107,7 @@ def __new__(
102107
return new_obj
103108
# zero copy if buffer is a usm backed array-like thing
104109
elif hasattr(buffer, array_interface_property):
105-
dprint("dparray::ndarray __new__ buffer", array_interface_property)
110+
dprint("numpy_usm_shared::ndarray __new__ buffer", array_interface_property)
106111
# also check for array interface
107112
new_obj = np.ndarray.__new__(
108113
subtype,
@@ -124,7 +129,7 @@ def __new__(
124129
)
125130
return new_obj
126131
else:
127-
dprint("dparray::ndarray __new__ buffer not None and not sycl_usm")
132+
dprint("numpy_usm_shared::ndarray __new__ buffer not None and not sycl_usm")
128133
nelems = np.prod(shape)
129134
# must copy
130135
ar = np.ndarray(
@@ -158,6 +163,9 @@ def __new__(
158163
)
159164
return new_obj
160165

166+
def __sycl_usm_array_interface__(self):
167+
return self._getter_sycl_usm_array_interface()
168+
161169
def _getter_sycl_usm_array_interface_(self):
162170
ary_iface = self.__array_interface__
163171
_base = _get_usm_base(self)
@@ -186,6 +194,9 @@ def __array_finalize__(self, obj):
186194
# subclass of ndarray, including our own.
187195
if hasattr(obj, array_interface_property):
188196
return
197+
for ext_checker in ndarray.external_usm_checkers:
198+
if ext_checker(obj):
199+
return
189200
if isinstance(obj, np.ndarray):
190201
ob = self
191202
while isinstance(ob, np.ndarray):
@@ -200,7 +211,7 @@ def __array_finalize__(self, obj):
200211
)
201212

202213
# Tell Numba to not treat this type just like a NumPy ndarray but to propagate its type.
203-
# This way it will use the custom dparray allocator.
214+
# This way it will use the custom numpy_usm_shared allocator.
204215
__numba_no_subtype_ndarray__ = True
205216

206217
# Convert to a NumPy ndarray.
@@ -234,8 +245,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
234245
else:
235246
return NotImplemented
236247
# Have to avoid recursive calls to array_ufunc here.
237-
# If no out kwarg then we create a dparray out so that we get
238-
# USM memory. However, if kwarg has dparray-typed out then
248+
# If no out kwarg then we create a numpy_usm_shared out so that we get
249+
# USM memory. However, if kwarg has numpy_usm_shared-typed out then
239250
# array_ufunc is called recursively so we cast out as regular
240251
# NumPy ndarray (having a USM data pointer).
241252
if kwargs.get("out", None) is None:
@@ -246,7 +257,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
246257
out_as_np = np.ndarray(out.shape, out.dtype, out)
247258
kwargs["out"] = out_as_np
248259
else:
249-
# If they manually gave dparray as out kwarg then we have to also
260+
# If they manually gave numpy_usm_shared as out kwarg then we have to also
250261
# cast as regular NumPy ndarray to avoid recursion.
251262
if isinstance(kwargs["out"], ndarray):
252263
out = kwargs["out"]
@@ -271,7 +282,7 @@ def isdef(x):
271282
cname = c[0]
272283
if isdef(cname):
273284
continue
274-
# For now we do the simple thing and copy the types from NumPy module into dparray module.
285+
# For now we do the simple thing and copy the types from NumPy module into numpy_usm_shared module.
275286
new_func = "%s = np.%s" % (cname, cname)
276287
try:
277288
the_code = compile(new_func, "__init__", "exec")

0 commit comments

Comments
 (0)