Skip to content

Commit e24fe4e

Browse files
committed
impl iterative print corners
1 parent b4eafff commit e24fe4e

File tree

1 file changed

+21
-22
lines changed

1 file changed

+21
-22
lines changed

dpctl/tensor/_print.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616

1717
import contextlib
18+
import itertools
1819
import operator
1920

2021
import numpy as np
@@ -223,10 +224,11 @@ def print_options(*args, **kwargs):
223224

224225

225226
def _nd_corners(arr_in, edge_items):
226-
arr_ndim = arr_in.ndim
227+
_shape = arr_in.shape
228+
max_shape = 2 * edge_items + 1
227229
res_shape = tuple(
228-
2 * edge_items if arr_in.shape[i] > 2 * edge_items else arr_in.shape[i]
229-
for i in range(arr_ndim)
230+
max_shape if _shape[i] > max_shape else _shape[i]
231+
for i in range(arr_in.ndim)
230232
)
231233

232234
arr_out = dpt.empty(
@@ -236,29 +238,27 @@ def _nd_corners(arr_in, edge_items):
236238
sycl_queue=arr_in.sycl_queue,
237239
)
238240

241+
blocks = []
242+
for i in range(len(_shape)):
243+
if _shape[i] > max_shape:
244+
blocks.append(
245+
(
246+
np.s_[:edge_items],
247+
np.s_[-edge_items:],
248+
)
249+
)
250+
else:
251+
blocks.append((np.s_[:],))
252+
239253
hev_list = []
240-
for corner in range(arr_ndim**2):
241-
slices = ()
242-
tmp = bin(corner).replace("0b", "").zfill(arr_ndim)
243-
244-
for dim in reversed(range(arr_ndim)):
245-
if arr_in.shape[dim] < 2 * edge_items:
246-
slices = (np.s_[:],) + slices
247-
else:
248-
ind = (-1) ** int(tmp[dim]) * edge_items
249-
if ind < 0:
250-
slices = (np.s_[-edge_items::],) + slices
251-
else:
252-
slices = (np.s_[:edge_items:],) + slices
254+
for slc in itertools.product(*blocks):
253255
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
254-
src=arr_in[slices],
255-
dst=arr_out[slices],
256-
sycl_queue=arr_in.sycl_queue,
256+
src=arr_in[slc], dst=arr_out[slc], sycl_queue=arr_in.sycl_queue
257257
)
258258
hev_list.append(hev)
259259

260260
dpctl.SyclEvent.wait_for(hev_list)
261-
return arr_out
261+
return dpt.asnumpy(arr_out)
262262

263263

264264
def usm_ndarray_str(
@@ -365,8 +365,7 @@ def usm_ndarray_str(
365365
edge_items = options["edgeitems"]
366366

367367
if x.size > threshold:
368-
# need edge_items + 1 elements for np.array2string to abbreviate
369-
data = dpt.asnumpy(_nd_corners(x, edge_items + 1))
368+
data = _nd_corners(x, edge_items)
370369
options["threshold"] = 0
371370
else:
372371
data = dpt.asnumpy(x)

0 commit comments

Comments
 (0)