Skip to content

Commit 43482fd

Browse files
committed
impl iterative print corners
1 parent b4eafff commit 43482fd

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

dpctl/tensor/_print.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616

1717
import contextlib
18+
import itertools
1819
import operator
1920

2021
import numpy as np
@@ -236,20 +237,28 @@ def _nd_corners(arr_in, edge_items):
236237
sycl_queue=arr_in.sycl_queue,
237238
)
238239

240+
split_dim = sum(
241+
arr_in.shape[dim] > 2 * edge_items for dim in range(arr_ndim)
242+
)
243+
blocks = [
244+
list(ele) for ele in list(itertools.product(*[[0, 1]] * split_dim))
245+
]
246+
for dim in range(arr_ndim):
247+
if arr_in.shape[dim] <= 2 * edge_items:
248+
for blk in blocks:
249+
blk.insert(dim, 0)
250+
239251
hev_list = []
240-
for corner in range(arr_ndim**2):
252+
for blk in blocks:
241253
slices = ()
242-
tmp = bin(corner).replace("0b", "").zfill(arr_ndim)
243-
244254
for dim in reversed(range(arr_ndim)):
245-
if arr_in.shape[dim] < 2 * edge_items:
246-
slices = (np.s_[:],) + slices
247-
else:
248-
ind = (-1) ** int(tmp[dim]) * edge_items
249-
if ind < 0:
255+
if arr_in.shape[dim] > 2 * edge_items:
256+
if blk[dim] == 1:
250257
slices = (np.s_[-edge_items::],) + slices
251258
else:
252259
slices = (np.s_[:edge_items:],) + slices
260+
else:
261+
slices = (np.s_[:],) + slices
253262
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
254263
src=arr_in[slices],
255264
dst=arr_out[slices],

0 commit comments

Comments
 (0)