Skip to content

Commit f492493

Browse files
Merge pull request #1778 from IntelPython/take-along-axis
2 parents b95988d + d07de32 commit f492493

File tree

6 files changed

+259
-7
lines changed

6 files changed

+259
-7
lines changed

docs/doc_sources/api_reference/dpctl/tensor.indexing_functions.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ by either integral arrays of indices or boolean mask arrays.
1515
place
1616
put
1717
take
18+
take_along_axis

dpctl/tensor/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,14 @@
6060
)
6161
from dpctl.tensor._device import Device
6262
from dpctl.tensor._dlpack import from_dlpack
63-
from dpctl.tensor._indexing_functions import extract, nonzero, place, put, take
63+
from dpctl.tensor._indexing_functions import (
64+
extract,
65+
nonzero,
66+
place,
67+
put,
68+
take,
69+
take_along_axis,
70+
)
6471
from dpctl.tensor._linear_algebra_functions import (
6572
matmul,
6673
matrix_transpose,
@@ -376,4 +383,5 @@
376383
"nextafter",
377384
"diff",
378385
"count_nonzero",
386+
"take_along_axis",
379387
]

dpctl/tensor/_copy_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -795,13 +795,18 @@ def _nonzero_impl(ary):
795795
return res
796796

797797

798-
def _take_multi_index(ary, inds, p):
798+
def _take_multi_index(ary, inds, p, mode=0):
799799
if not isinstance(ary, dpt.usm_ndarray):
800800
raise TypeError(
801801
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
802802
)
803803
ary_nd = ary.ndim
804804
p = normalize_axis_index(operator.index(p), ary_nd)
805+
mode = operator.index(mode)
806+
if mode not in [0, 1]:
807+
raise ValueError(
808+
"Invalid value for mode keyword, only 0 or 1 is supported"
809+
)
805810
queues_ = [
806811
ary.sycl_queue,
807812
]
@@ -860,7 +865,7 @@ def _take_multi_index(ary, inds, p):
860865
ind=inds,
861866
dst=res,
862867
axis_start=p,
863-
mode=0,
868+
mode=mode,
864869
sycl_queue=exec_q,
865870
depends=dep_ev,
866871
)

dpctl/tensor/_indexing_functions.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import dpctl.tensor._tensor_impl as ti
2222
import dpctl.utils
2323

24-
from ._copy_utils import _extract_impl, _nonzero_impl
24+
from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
2525
from ._numpy_helper import normalize_axis_index
2626

2727

@@ -423,3 +423,82 @@ def nonzero(arr):
423423
if arr.ndim == 0:
424424
raise ValueError("Array of positive rank is expected")
425425
return _nonzero_impl(arr)
426+
427+
428+
def _range(sh_i, i, nd, q, usm_t, dt):
429+
ind = dpt.arange(sh_i, dtype=dt, usm_type=usm_t, sycl_queue=q)
430+
ind.shape = tuple(sh_i if i == j else 1 for j in range(nd))
431+
return ind
432+
433+
434+
def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"):
435+
"""
436+
Returns elements from an array at the one-dimensional indices specified
437+
by ``indices`` along a provided ``axis``.
438+
439+
Args:
440+
x (usm_ndarray):
441+
input array. Must be compatible with ``indices``, except for the
442+
axis (dimension) specified by ``axis``.
443+
indices (usm_ndarray):
444+
array indices. Must have the same rank (i.e., number of dimensions)
445+
as ``x``.
446+
axis: int
447+
axis along which to select values. If ``axis`` is negative, the
448+
function determines the axis along which to select values by
449+
counting from the last dimension. Default: ``-1``.
450+
mode (str, optional):
451+
How out-of-bounds indices will be handled. Possible values
452+
are:
453+
454+
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
455+
negative indices.
456+
- ``"clip"``: clips indices to (``0 <= i < n``).
457+
458+
Default: ``"wrap"``.
459+
460+
Returns:
461+
usm_ndarray:
462+
an array having the same data type as ``x``. The returned array has
463+
the same rank (i.e., number of dimensions) as ``x`` and a shape
464+
determined according to :ref:`broadcasting`, except for the axis
465+
(dimension) specified by ``axis`` whose size must equal the size
466+
of the corresponding axis (dimension) in ``indices``.
467+
468+
Note:
469+
Treatment of the out-of-bound indices in ``indices`` array is controlled
470+
by the value of ``mode`` keyword.
471+
"""
472+
if not isinstance(x, dpt.usm_ndarray):
473+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
474+
if not isinstance(indices, dpt.usm_ndarray):
475+
raise TypeError(
476+
f"Expected dpctl.tensor.usm_ndarray, got {type(indices)}"
477+
)
478+
x_nd = x.ndim
479+
if x_nd != indices.ndim:
480+
raise ValueError(
481+
"Number of dimensions in the first and the second "
482+
"argument arrays must be equal"
483+
)
484+
pp = normalize_axis_index(operator.index(axis), x_nd)
485+
out_usm_type = dpctl.utils.get_coerced_usm_type(
486+
(x.usm_type, indices.usm_type)
487+
)
488+
exec_q = dpctl.utils.get_execution_queue((x.sycl_queue, indices.sycl_queue))
489+
if exec_q is None:
490+
raise dpctl.utils.ExecutionPlacementError(
491+
"Execution placement can not be unambiguously inferred "
492+
"from input arguments. "
493+
)
494+
mode_i = _get_indexing_mode(mode)
495+
indexes_dt = ti.default_device_index_type(exec_q.sycl_device)
496+
_ind = tuple(
497+
(
498+
indices
499+
if i == pp
500+
else _range(x.shape[i], i, x_nd, exec_q, out_usm_type, indexes_dt)
501+
)
502+
for i in range(x_nd)
503+
)
504+
return _take_multi_index(x, _ind, 0, mode=mode_i)

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1535,5 +1535,151 @@ def test_advanced_integer_indexing_cast_indices():
15351535
inds1 = dpt.astype(inds0, "u4")
15361536
inds2 = dpt.astype(inds0, "u8")
15371537
x = dpt.ones((3, 4, 5, 6), dtype="i4")
1538+
# test getitem
15381539
with pytest.raises(ValueError):
15391540
x[inds0, inds1, inds2, ...]
1541+
# test setitem
1542+
with pytest.raises(ValueError):
1543+
x[inds0, inds1, inds2, ...] = 1
1544+
1545+
1546+
def test_take_along_axis():
1547+
get_queue_or_skip()
1548+
1549+
n0, n1, n2 = 3, 5, 7
1550+
x = dpt.reshape(dpt.arange(n0 * n1 * n2), (n0, n1, n2))
1551+
ind_dt = dpt.__array_namespace_info__().default_dtypes(
1552+
device=x.sycl_device
1553+
)["indexing"]
1554+
ind0 = dpt.ones((1, n1, n2), dtype=ind_dt)
1555+
ind1 = dpt.ones((n0, 1, n2), dtype=ind_dt)
1556+
ind2 = dpt.ones((n0, n1, 1), dtype=ind_dt)
1557+
1558+
y0 = dpt.take_along_axis(x, ind0, axis=0)
1559+
assert y0.shape == ind0.shape
1560+
y1 = dpt.take_along_axis(x, ind1, axis=1)
1561+
assert y1.shape == ind1.shape
1562+
y2 = dpt.take_along_axis(x, ind2, axis=2)
1563+
assert y2.shape == ind2.shape
1564+
1565+
1566+
def test_take_along_axis_validation():
1567+
# type check on the first argument
1568+
with pytest.raises(TypeError):
1569+
dpt.take_along_axis(tuple(), list())
1570+
get_queue_or_skip()
1571+
n1, n2 = 2, 5
1572+
x = dpt.ones(n1 * n2)
1573+
# type check on the second argument
1574+
with pytest.raises(TypeError):
1575+
dpt.take_along_axis(x, list())
1576+
x_dev = x.sycl_device
1577+
info_ = dpt.__array_namespace_info__()
1578+
def_dtypes = info_.default_dtypes(device=x_dev)
1579+
ind_dt = def_dtypes["indexing"]
1580+
ind = dpt.zeros(1, dtype=ind_dt)
1581+
# axis valudation
1582+
with pytest.raises(ValueError):
1583+
dpt.take_along_axis(x, ind, axis=1)
1584+
# mode validation
1585+
with pytest.raises(ValueError):
1586+
dpt.take_along_axis(x, ind, axis=0, mode="invalid")
1587+
# same array-ranks validation
1588+
with pytest.raises(ValueError):
1589+
dpt.take_along_axis(dpt.reshape(x, (n1, n2)), ind)
1590+
# check compute-follows-data
1591+
q2 = dpctl.SyclQueue(x_dev, property="enable_profiling")
1592+
ind2 = dpt.zeros(1, dtype=ind_dt, sycl_queue=q2)
1593+
with pytest.raises(ExecutionPlacementError):
1594+
dpt.take_along_axis(x, ind2)
1595+
1596+
1597+
def check__extract_impl_validation(fn):
1598+
x = dpt.ones(10)
1599+
ind = dpt.ones(10, dtype="?")
1600+
with pytest.raises(TypeError):
1601+
fn(list(), ind)
1602+
with pytest.raises(TypeError):
1603+
fn(x, list())
1604+
q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
1605+
ind2 = dpt.ones(10, dtype="?", sycl_queue=q2)
1606+
with pytest.raises(ExecutionPlacementError):
1607+
fn(x, ind2)
1608+
with pytest.raises(ValueError):
1609+
fn(x, ind, 1)
1610+
1611+
1612+
def check__nonzero_impl_validation(fn):
1613+
with pytest.raises(TypeError):
1614+
fn(list())
1615+
1616+
1617+
def check__take_multi_index(fn):
1618+
x = dpt.ones(10)
1619+
x_dev = x.sycl_device
1620+
info_ = dpt.__array_namespace_info__()
1621+
def_dtypes = info_.default_dtypes(device=x_dev)
1622+
ind_dt = def_dtypes["indexing"]
1623+
ind = dpt.arange(10, dtype=ind_dt)
1624+
with pytest.raises(TypeError):
1625+
fn(list(), tuple(), 1)
1626+
with pytest.raises(ValueError):
1627+
fn(x, (ind,), 0, mode=2)
1628+
with pytest.raises(ValueError):
1629+
fn(x, (None,), 1)
1630+
with pytest.raises(IndexError):
1631+
fn(x, (x,), 1)
1632+
q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
1633+
ind2 = dpt.arange(10, dtype=ind_dt, sycl_queue=q2)
1634+
with pytest.raises(ExecutionPlacementError):
1635+
fn(x, (ind2,), 0)
1636+
m = dpt.ones((10, 10))
1637+
ind_1 = dpt.arange(10, dtype="i8")
1638+
ind_2 = dpt.arange(10, dtype="u8")
1639+
with pytest.raises(ValueError):
1640+
fn(m, (ind_1, ind_2), 0)
1641+
1642+
1643+
def check__place_impl_validation(fn):
1644+
with pytest.raises(TypeError):
1645+
fn(list(), list(), list())
1646+
x = dpt.ones(10)
1647+
with pytest.raises(TypeError):
1648+
fn(x, list(), list())
1649+
q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling")
1650+
mask2 = dpt.ones(10, dtype="?", sycl_queue=q2)
1651+
with pytest.raises(ExecutionPlacementError):
1652+
fn(x, mask2, 1)
1653+
x2 = dpt.ones((5, 5))
1654+
mask2 = dpt.ones((5, 5), dtype="?")
1655+
with pytest.raises(ValueError):
1656+
fn(x2, mask2, x2, axis=1)
1657+
1658+
1659+
def check__put_multi_index_validation(fn):
1660+
with pytest.raises(TypeError):
1661+
fn(list(), list(), 0, list())
1662+
x = dpt.ones(10)
1663+
inds = dpt.arange(10, dtype="i8")
1664+
vals = dpt.zeros(10)
1665+
# test inds which is not a tuple/list
1666+
fn(x, inds, 0, vals)
1667+
x2 = dpt.ones((5, 5))
1668+
ind1 = dpt.arange(5, dtype="i8")
1669+
ind2 = dpt.arange(5, dtype="u8")
1670+
with pytest.raises(ValueError):
1671+
fn(x2, (ind1, ind2), 0, x2)
1672+
with pytest.raises(TypeError):
1673+
fn(x2, (ind1, list()), 0, x2)
1674+
1675+
1676+
def test__copy_utils():
1677+
import dpctl.tensor._copy_utils as cu
1678+
1679+
get_queue_or_skip()
1680+
1681+
check__extract_impl_validation(cu._extract_impl)
1682+
check__nonzero_impl_validation(cu._nonzero_impl)
1683+
check__take_multi_index(cu._take_multi_index)
1684+
check__place_impl_validation(cu._place_impl)
1685+
check__put_multi_index_validation(cu._put_multi_index)

dpctl/tests/test_usm_ndarray_sorting.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,24 @@ def test_argsort_axis0():
177177
x = dpt.reshape(xf, (n, m))
178178
idx = dpt.argsort(x, axis=0)
179179

180-
conseq_idx = dpt.arange(m, dtype=idx.dtype)
181-
s = x[idx, conseq_idx[dpt.newaxis, :]]
180+
s = dpt.take_along_axis(x, idx, axis=0)
182181

183182
assert dpt.all(s[:-1, :] <= s[1:, :])
184183

185184

185+
def test_argsort_axis1():
186+
get_queue_or_skip()
187+
188+
n, m = 200, 30
189+
xf = dpt.arange(n * m, 0, step=-1, dtype="i4")
190+
x = dpt.reshape(xf, (n, m))
191+
idx = dpt.argsort(x, axis=1)
192+
193+
s = dpt.take_along_axis(x, idx, axis=1)
194+
195+
assert dpt.all(s[:, :-1] <= s[:, 1:])
196+
197+
186198
def test_sort_strided():
187199
get_queue_or_skip()
188200

@@ -199,8 +211,9 @@ def test_argsort_strided():
199211
x_orig = dpt.arange(100, dtype="i4")
200212
x_flipped = dpt.flip(x_orig, axis=0)
201213
idx = dpt.argsort(x_flipped)
214+
s = dpt.take_along_axis(x_flipped, idx, axis=0)
202215

203-
assert dpt.all(x_flipped[idx] == x_orig)
216+
assert dpt.all(s == x_orig)
204217

205218

206219
def test_sort_0d_array():

0 commit comments

Comments
 (0)