Skip to content

Commit 0aa81d2

Browse files
committed
impl iterative print corners
1 parent b4eafff commit 0aa81d2

File tree

1 file changed

+20
-22
lines changed

1 file changed

+20
-22
lines changed

dpctl/tensor/_print.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616

1717
import contextlib
18+
import itertools
1819
import operator
1920

2021
import numpy as np
@@ -223,10 +224,10 @@ def print_options(*args, **kwargs):
223224

224225

225226
def _nd_corners(arr_in, edge_items):
226-
arr_ndim = arr_in.ndim
227+
_shape = arr_in.shape
227228
res_shape = tuple(
228-
2 * edge_items if arr_in.shape[i] > 2 * edge_items else arr_in.shape[i]
229-
for i in range(arr_ndim)
229+
2 * (edge_items + 1) if _shape[i] > 2 * (edge_items + 1) else _shape[i]
230+
for i in range(arr_in.ndim)
230231
)
231232

232233
arr_out = dpt.empty(
@@ -236,29 +237,27 @@ def _nd_corners(arr_in, edge_items):
236237
sycl_queue=arr_in.sycl_queue,
237238
)
238239

240+
blocks = []
241+
for i in range(len(_shape)):
242+
if _shape[i] > 2 * (edge_items + 1):
243+
blocks.append(
244+
(
245+
np.s_[: edge_items + 1],
246+
np.s_[-edge_items - 1 :],
247+
)
248+
)
249+
else:
250+
blocks.append((np.s_[:],))
251+
239252
hev_list = []
240-
for corner in range(arr_ndim**2):
241-
slices = ()
242-
tmp = bin(corner).replace("0b", "").zfill(arr_ndim)
243-
244-
for dim in reversed(range(arr_ndim)):
245-
if arr_in.shape[dim] < 2 * edge_items:
246-
slices = (np.s_[:],) + slices
247-
else:
248-
ind = (-1) ** int(tmp[dim]) * edge_items
249-
if ind < 0:
250-
slices = (np.s_[-edge_items::],) + slices
251-
else:
252-
slices = (np.s_[:edge_items:],) + slices
253+
for slc in itertools.product(*blocks):
253254
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
254-
src=arr_in[slices],
255-
dst=arr_out[slices],
256-
sycl_queue=arr_in.sycl_queue,
255+
src=arr_in[slc], dst=arr_out[slc], sycl_queue=arr_in.sycl_queue
257256
)
258257
hev_list.append(hev)
259258

260259
dpctl.SyclEvent.wait_for(hev_list)
261-
return arr_out
260+
return dpt.asnumpy(arr_out)
262261

263262

264263
def usm_ndarray_str(
@@ -365,8 +364,7 @@ def usm_ndarray_str(
365364
edge_items = options["edgeitems"]
366365

367366
if x.size > threshold:
368-
# need edge_items + 1 elements for np.array2string to abbreviate
369-
data = dpt.asnumpy(_nd_corners(x, edge_items + 1))
367+
data = _nd_corners(x, edge_items)
370368
options["threshold"] = 0
371369
else:
372370
data = dpt.asnumpy(x)

0 commit comments

Comments
 (0)