15
15
# limitations under the License.
16
16
17
17
import contextlib
18
+ import itertools
18
19
import operator
19
20
20
21
import numpy as np
@@ -223,10 +224,11 @@ def print_options(*args, **kwargs):
223
224
224
225
225
226
def _nd_corners (arr_in , edge_items ):
226
- arr_ndim = arr_in .ndim
227
+ _shape = arr_in .shape
228
+ max_shape = 2 * edge_items + 1
227
229
res_shape = tuple (
228
- 2 * edge_items if arr_in . shape [i ] > 2 * edge_items else arr_in . shape [i ]
229
- for i in range (arr_ndim )
230
+ max_shape if _shape [i ] > max_shape else _shape [i ]
231
+ for i in range (arr_in . ndim )
230
232
)
231
233
232
234
arr_out = dpt .empty (
@@ -236,29 +238,27 @@ def _nd_corners(arr_in, edge_items):
236
238
sycl_queue = arr_in .sycl_queue ,
237
239
)
238
240
241
+ blocks = []
242
+ for i in range (len (_shape )):
243
+ if _shape [i ] > max_shape :
244
+ blocks .append (
245
+ (
246
+ np .s_ [:edge_items ],
247
+ np .s_ [- edge_items :],
248
+ )
249
+ )
250
+ else :
251
+ blocks .append ((np .s_ [:],))
252
+
239
253
hev_list = []
240
- for corner in range (arr_ndim ** 2 ):
241
- slices = ()
242
- tmp = bin (corner ).replace ("0b" , "" ).zfill (arr_ndim )
243
-
244
- for dim in reversed (range (arr_ndim )):
245
- if arr_in .shape [dim ] < 2 * edge_items :
246
- slices = (np .s_ [:],) + slices
247
- else :
248
- ind = (- 1 ) ** int (tmp [dim ]) * edge_items
249
- if ind < 0 :
250
- slices = (np .s_ [- edge_items ::],) + slices
251
- else :
252
- slices = (np .s_ [:edge_items :],) + slices
254
+ for slc in itertools .product (* blocks ):
253
255
hev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
254
- src = arr_in [slices ],
255
- dst = arr_out [slices ],
256
- sycl_queue = arr_in .sycl_queue ,
256
+ src = arr_in [slc ], dst = arr_out [slc ], sycl_queue = arr_in .sycl_queue
257
257
)
258
258
hev_list .append (hev )
259
259
260
260
dpctl .SyclEvent .wait_for (hev_list )
261
- return arr_out
261
+ return dpt . asnumpy ( arr_out )
262
262
263
263
264
264
def usm_ndarray_str (
@@ -365,8 +365,7 @@ def usm_ndarray_str(
365
365
edge_items = options ["edgeitems" ]
366
366
367
367
if x .size > threshold :
368
- # need edge_items + 1 elements for np.array2string to abbreviate
369
- data = dpt .asnumpy (_nd_corners (x , edge_items + 1 ))
368
+ data = _nd_corners (x , edge_items )
370
369
options ["threshold" ] = 0
371
370
else :
372
371
data = dpt .asnumpy (x )
0 commit comments