15
15
# limitations under the License.
16
16
17
17
import contextlib
18
+ import itertools
18
19
import operator
19
20
20
21
import numpy as np
@@ -223,10 +224,10 @@ def print_options(*args, **kwargs):
223
224
224
225
225
226
def _nd_corners (arr_in , edge_items ):
226
- arr_ndim = arr_in .ndim
227
+ _shape = arr_in .shape
227
228
res_shape = tuple (
228
- 2 * edge_items if arr_in . shape [i ] > 2 * edge_items else arr_in . shape [i ]
229
- for i in range (arr_ndim )
229
+ 2 * ( edge_items + 1 ) if _shape [i ] > 2 * ( edge_items + 1 ) else _shape [i ]
230
+ for i in range (arr_in . ndim )
230
231
)
231
232
232
233
arr_out = dpt .empty (
@@ -236,29 +237,27 @@ def _nd_corners(arr_in, edge_items):
236
237
sycl_queue = arr_in .sycl_queue ,
237
238
)
238
239
240
+ blocks = []
241
+ for i in range (len (_shape )):
242
+ if _shape [i ] > 2 * (edge_items + 1 ):
243
+ blocks .append (
244
+ (
245
+ np .s_ [: edge_items + 1 ],
246
+ np .s_ [- edge_items - 1 :],
247
+ )
248
+ )
249
+ else :
250
+ blocks .append ((np .s_ [:],))
251
+
239
252
hev_list = []
240
- for corner in range (arr_ndim ** 2 ):
241
- slices = ()
242
- tmp = bin (corner ).replace ("0b" , "" ).zfill (arr_ndim )
243
-
244
- for dim in reversed (range (arr_ndim )):
245
- if arr_in .shape [dim ] < 2 * edge_items :
246
- slices = (np .s_ [:],) + slices
247
- else :
248
- ind = (- 1 ) ** int (tmp [dim ]) * edge_items
249
- if ind < 0 :
250
- slices = (np .s_ [- edge_items ::],) + slices
251
- else :
252
- slices = (np .s_ [:edge_items :],) + slices
253
+ for slc in itertools .product (* blocks ):
253
254
hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
254
- src = arr_in [slices ],
255
- dst = arr_out [slices ],
256
- sycl_queue = arr_in .sycl_queue ,
255
+ src = arr_in [slc ], dst = arr_out [slc ], sycl_queue = arr_in .sycl_queue
257
256
)
258
257
hev_list .append (hev )
259
258
260
259
dpctl .SyclEvent .wait_for (hev_list )
261
- return arr_out
260
+ return dpt . asnumpy ( arr_out )
262
261
263
262
264
263
def usm_ndarray_str (
@@ -365,8 +364,7 @@ def usm_ndarray_str(
365
364
edge_items = options ["edgeitems" ]
366
365
367
366
if x .size > threshold :
368
- # need edge_items + 1 elements for np.array2string to abbreviate
369
- data = dpt .asnumpy (_nd_corners (x , edge_items + 1 ))
367
+ data = _nd_corners (x , edge_items )
370
368
options ["threshold" ] = 0
371
369
else :
372
370
data = dpt .asnumpy (x )
0 commit comments