1
- ##===---------- dparray .py - dpctl -------*- Python -*----===##
1
+ ##===---------- numpy_usm_shared .py - dpctl -------*- Python -*----===##
2
2
##
3
3
## Data Parallel Control (dpCtl)
4
4
##
19
19
##===----------------------------------------------------------------------===##
20
20
###
21
21
### \file
22
- ### This file implements a dparray - USM aware implementation of ndarray.
22
+ ### This file implements a numpy_usm_shared - USM aware implementation of ndarray.
23
23
##===----------------------------------------------------------------------===##
24
24
25
25
import numpy as np
@@ -70,12 +70,17 @@ class ndarray(np.ndarray):
70
70
with a foreign allocator.
71
71
"""
72
72
73
+ external_usm_checkers = []
74
+
75
+ def add_external_usm_checker (func ):
76
+ ndarray .external_usm_checkers .append (func )
77
+
73
78
def __new__ (
74
79
subtype , shape , dtype = float , buffer = None , offset = 0 , strides = None , order = None
75
80
):
76
81
# Create a new array.
77
82
if buffer is None :
78
- dprint ("dparray ::ndarray __new__ buffer None" )
83
+ dprint ("numpy_usm_shared ::ndarray __new__ buffer None" )
79
84
nelems = np .prod (shape )
80
85
dt = np .dtype (dtype )
81
86
isz = dt .itemsize
@@ -102,7 +107,7 @@ def __new__(
102
107
return new_obj
103
108
# zero copy if buffer is a usm backed array-like thing
104
109
elif hasattr (buffer , array_interface_property ):
105
- dprint ("dparray ::ndarray __new__ buffer" , array_interface_property )
110
+ dprint ("numpy_usm_shared ::ndarray __new__ buffer" , array_interface_property )
106
111
# also check for array interface
107
112
new_obj = np .ndarray .__new__ (
108
113
subtype ,
@@ -124,7 +129,7 @@ def __new__(
124
129
)
125
130
return new_obj
126
131
else :
127
- dprint ("dparray ::ndarray __new__ buffer not None and not sycl_usm" )
132
+ dprint ("numpy_usm_shared ::ndarray __new__ buffer not None and not sycl_usm" )
128
133
nelems = np .prod (shape )
129
134
# must copy
130
135
ar = np .ndarray (
@@ -158,6 +163,9 @@ def __new__(
158
163
)
159
164
return new_obj
160
165
166
+ def __sycl_usm_array_interface__ (self ):
167
+ return self ._getter_sycl_usm_array_interface ()
168
+
161
169
def _getter_sycl_usm_array_interface_ (self ):
162
170
ary_iface = self .__array_interface__
163
171
_base = _get_usm_base (self )
@@ -186,6 +194,9 @@ def __array_finalize__(self, obj):
186
194
# subclass of ndarray, including our own.
187
195
if hasattr (obj , array_interface_property ):
188
196
return
197
+ for ext_checker in ndarray .external_usm_checkers :
198
+ if ext_checker (obj ):
199
+ return
189
200
if isinstance (obj , np .ndarray ):
190
201
ob = self
191
202
while isinstance (ob , np .ndarray ):
@@ -200,7 +211,7 @@ def __array_finalize__(self, obj):
200
211
)
201
212
202
213
# Tell Numba to not treat this type just like a NumPy ndarray but to propagate its type.
203
- # This way it will use the custom dparray allocator.
214
+ # This way it will use the custom numpy_usm_shared allocator.
204
215
__numba_no_subtype_ndarray__ = True
205
216
206
217
# Convert to a NumPy ndarray.
@@ -234,8 +245,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
234
245
else :
235
246
return NotImplemented
236
247
# Have to avoid recursive calls to array_ufunc here.
237
- # If no out kwarg then we create a dparray out so that we get
238
- # USM memory. However, if kwarg has dparray -typed out then
248
+ # If no out kwarg then we create a numpy_usm_shared out so that we get
249
+ # USM memory. However, if kwarg has numpy_usm_shared -typed out then
239
250
# array_ufunc is called recursively so we cast out as regular
240
251
# NumPy ndarray (having a USM data pointer).
241
252
if kwargs .get ("out" , None ) is None :
@@ -246,7 +257,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
246
257
out_as_np = np .ndarray (out .shape , out .dtype , out )
247
258
kwargs ["out" ] = out_as_np
248
259
else :
249
- # If they manually gave dparray as out kwarg then we have to also
260
+ # If they manually gave numpy_usm_shared as out kwarg then we have to also
250
261
# cast as regular NumPy ndarray to avoid recursion.
251
262
if isinstance (kwargs ["out" ], ndarray ):
252
263
out = kwargs ["out" ]
@@ -271,7 +282,7 @@ def isdef(x):
271
282
cname = c [0 ]
272
283
if isdef (cname ):
273
284
continue
274
- # For now we do the simple thing and copy the types from NumPy module into dparray module.
285
+ # For now we do the simple thing and copy the types from NumPy module into numpy_usm_shared module.
275
286
new_func = "%s = np.%s" % (cname , cname )
276
287
try :
277
288
the_code = compile (new_func , "__init__" , "exec" )
0 commit comments