Skip to content

Commit 64fa95e

Browse files
Deployed _copy_usm_ndarray_for_roll_nd in dpt.roll
1 parent 987af9e commit 64fa95e

File tree

1 file changed

+12
-21
lines changed

1 file changed

+12
-21
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 12 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,8 @@
15 15
# limitations under the License.
16 16

17 17

18-
from itertools import chain, product, repeat
18+
import operator
19+
from itertools import chain, repeat
19 20

20 21
import numpy as np
21 22
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
@@ -426,6 +427,7 @@ def roll(X, shift, axis=None):
426 427
if not isinstance(X, dpt.usm_ndarray):
427 428
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
428 429
if axis is None:
430+
shift = operator.index(shift)
429 431
res = dpt.empty(
430 432
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue
431 433
)
@@ -438,31 +440,20 @@ def roll(X, shift, axis=None):
438 440
broadcasted = np.broadcast(shift, axis)
439 441
if broadcasted.ndim > 1:
440 442
raise ValueError("'shift' and 'axis' should be scalars or 1D sequences")
441-
shifts = {ax: 0 for ax in range(X.ndim)}
443+
shifts = [
444+
0,
445+
] * X.ndim
442 446
for sh, ax in broadcasted:
443 447
shifts[ax] += sh
444-
rolls = [((np.s_[:], np.s_[:]),)] * X.ndim
445-
for ax, offset in shifts.items():
446-
offset %= X.shape[ax] or 1
447-
if offset:
448-
# (original, result), (original, result)
449-
rolls[ax] = (
450-
(np.s_[:-offset], np.s_[offset:]),
451-
(np.s_[-offset:], np.s_[:offset]),
452-
)
453 448

449+
exec_q = X.sycl_queue
454 450
res = dpt.empty(
455-
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue
451+
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q
456 452
)
457-
hev_list = []
458-
for indices in product(*rolls):
459-
arr_index, res_index = zip(*indices)
460-
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
461-
src=X[arr_index], dst=res[res_index], sycl_queue=X.sycl_queue
462-
)
463-
hev_list.append(hev)
464-
465-
dpctl.SyclEvent.wait_for(hev_list)
453+
ht_e, _ = ti._copy_usm_ndarray_for_roll_nd(
454+
src=X, dst=res, shifts=shifts, sycl_queue=exec_q
455+
)
456+
ht_e.wait()
466 457
return res
467 458

468 459

0 commit comments

Comments
 (0)