Skip to content

Commit a1078c7

Browse files
committed
Integer indexing "wrap" mode now default
- For performance reasons, "wrap" now clips positive indices and wraps negative indices
1 parent 4e06ba9 commit a1078c7

File tree

4 files changed

+20
-87
lines changed

4 files changed

+20
-87
lines changed

dpctl/tensor/_indexing_functions.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,17 @@
2727

2828

2929
def _get_indexing_mode(name):
30-
modes = {"default": 0, "clip": 1, "wrap": 2}
30+
modes = {"wrap": 0, "clip": 1}
3131
try:
3232
return modes[name]
3333
except KeyError:
3434
raise ValueError(
35-
"`mode` must be `default`, `clip`, or `wrap`."
36-
"Got `{}`.".format(name)
35+
"`mode` must be `wrap` or `clip`." "Got `{}`.".format(name)
3736
)
3837

3938

40-
def take(x, indices, /, *, axis=None, mode="default"):
41-
"""take(x, indices, axis=None, mode="default")
39+
def take(x, indices, /, *, axis=None, mode="wrap"):
40+
"""take(x, indices, axis=None, mode="wrap")
4241
4342
Takes elements from array along a given axis.
4443
@@ -53,11 +52,10 @@ def take(x, indices, /, *, axis=None, mode="default"):
5352
Default: `None`.
5453
mode:
5554
How out-of-bounds indices will be handled.
56-
"default" - clamps indices to (-n <= i < n), then wraps
55+
"wrap" - clamps indices to (-n <= i < n), then wraps
5756
negative indices.
5857
"clip" - clips indices to (0 <= i < n)
59-
"wrap" - wraps both negative and positive indices.
60-
Default: `"default"`.
58+
Default: `"wrap"`.
6159
6260
Returns:
6361
out: usm_ndarray
@@ -122,8 +120,8 @@ def take(x, indices, /, *, axis=None, mode="default"):
122120
return res
123121

124122

125-
def put(x, indices, vals, /, *, axis=None, mode="default"):
126-
"""put(x, indices, vals, axis=None, mode="default")
123+
def put(x, indices, vals, /, *, axis=None, mode="wrap"):
124+
"""put(x, indices, vals, axis=None, mode="wrap")
127125
128126
Puts values of an array into another array
129127
along a given axis.
@@ -142,11 +140,10 @@ def put(x, indices, vals, /, *, axis=None, mode="default"):
142140
Default: `None`.
143141
mode:
144142
How out-of-bounds indices will be handled.
145-
"default" - clamps indices to (-n <= i < n), then wraps
143+
"wrap" - clamps indices to (-n <= i < n), then wraps
146144
negative indices.
147145
"clip" - clips indices to (0 <= i < n)
148-
"wrap" - wraps both negative and positive indices.
149-
Default: `"default"`.
146+
Default: `"wrap"`.
150147
"""
151148
if not isinstance(x, dpt.usm_ndarray):
152149
raise TypeError(

dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp

Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ namespace py = pybind11;
4646
template <typename ProjectorT, typename Ty, typename indT> class take_kernel;
4747
template <typename ProjectorT, typename Ty, typename indT> class put_kernel;
4848

49-
class FancyIndex
49+
class WrapIndex
5050
{
5151
public:
52-
FancyIndex() = default;
52+
WrapIndex() = default;
5353

5454
void operator()(py::ssize_t max_item, py::ssize_t &ind) const
5555
{
@@ -73,20 +73,6 @@ class ClipIndex
7373
}
7474
};
7575

76-
class WrapIndex
77-
{
78-
public:
79-
WrapIndex() = default;
80-
81-
void operator()(py::ssize_t max_item, py::ssize_t &ind) const
82-
{
83-
max_item = std::max<py::ssize_t>(max_item, 1);
84-
ind = (ind < 0) ? (ind + max_item * ((-ind / max_item) + 1)) % max_item
85-
: ind % max_item;
86-
return;
87-
}
88-
};
89-
9076
template <typename ProjectorT, typename T, typename indT> class TakeFunctor
9177
{
9278
private:
@@ -361,22 +347,6 @@ sycl::event put_impl(sycl::queue q,
361347
return put_ev;
362348
}
363349

364-
template <typename fnT, typename T, typename indT> struct TakeFancyFactory
365-
{
366-
fnT get()
367-
{
368-
if constexpr (std::is_integral<indT>::value &&
369-
!std::is_same<indT, bool>::value) {
370-
fnT fn = take_impl<FancyIndex, T, indT>;
371-
return fn;
372-
}
373-
else {
374-
fnT fn = nullptr;
375-
return fn;
376-
}
377-
}
378-
};
379-
380350
template <typename fnT, typename T, typename indT> struct TakeWrapFactory
381351
{
382352
fnT get()
@@ -409,22 +379,6 @@ template <typename fnT, typename T, typename indT> struct TakeClipFactory
409379
}
410380
};
411381

412-
template <typename fnT, typename T, typename indT> struct PutFancyFactory
413-
{
414-
fnT get()
415-
{
416-
if constexpr (std::is_integral<indT>::value &&
417-
!std::is_same<indT, bool>::value) {
418-
fnT fn = put_impl<FancyIndex, T, indT>;
419-
return fn;
420-
}
421-
else {
422-
fnT fn = nullptr;
423-
return fn;
424-
}
425-
}
426-
};
427-
428382
template <typename fnT, typename T, typename indT> struct PutWrapFactory
429383
{
430384
fnT get()

dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,9 @@
3939

4040
#include "integer_advanced_indexing.hpp"
4141

42-
#define INDEXING_MODES 3
43-
#define FANCY_MODE 0
42+
#define INDEXING_MODES 2
43+
#define WRAP_MODE 0
4444
#define CLIP_MODE 1
45-
#define WRAP_MODE 2
4645

4746
namespace dpctl
4847
{
@@ -884,11 +883,6 @@ void init_advanced_indexing_dispatch_tables(void)
884883
{
885884
using namespace dpctl::tensor::detail;
886885

887-
using dpctl::tensor::kernels::indexing::TakeFancyFactory;
888-
DispatchTableBuilder<take_fn_ptr_t, TakeFancyFactory, num_types>
889-
dtb_takefancy;
890-
dtb_takefancy.populate_dispatch_table(take_dispatch_table[FANCY_MODE]);
891-
892886
using dpctl::tensor::kernels::indexing::TakeClipFactory;
893887
DispatchTableBuilder<take_fn_ptr_t, TakeClipFactory, num_types>
894888
dtb_takeclip;
@@ -899,10 +893,6 @@ void init_advanced_indexing_dispatch_tables(void)
899893
dtb_takewrap;
900894
dtb_takewrap.populate_dispatch_table(take_dispatch_table[WRAP_MODE]);
901895

902-
using dpctl::tensor::kernels::indexing::PutFancyFactory;
903-
DispatchTableBuilder<put_fn_ptr_t, PutFancyFactory, num_types> dtb_putfancy;
904-
dtb_putfancy.populate_dispatch_table(put_dispatch_table[FANCY_MODE]);
905-
906896
using dpctl::tensor::kernels::indexing::PutClipFactory;
907897
DispatchTableBuilder<put_fn_ptr_t, PutClipFactory, num_types> dtb_putclip;
908898
dtb_putclip.populate_dispatch_table(put_dispatch_table[CLIP_MODE]);

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -898,27 +898,19 @@ def test_integer_indexing_modes():
898898
x = dpt.arange(5, sycl_queue=q)
899899
x_np = dpt.asnumpy(x)
900900

901-
ind = dpt.asarray([-6, -3, 0, 2, 6], dtype=np.intp, sycl_queue=q)
902-
ind_np = dpt.asnumpy(ind)
901+
# wrapping negative indices
902+
ind = dpt.asarray([-4, -3, 0, 2, 4], dtype=np.intp, sycl_queue=q)
903903

904-
# wrapping
905904
res = dpt.take(x, ind, mode="wrap")
906-
expected_arr = np.take(x_np, ind_np, mode="wrap")
905+
expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="raise")
907906

908907
assert (dpt.asnumpy(res) == expected_arr).all()
909908

910909
# clipping to 0 (disabling negative indices)
911-
res = dpt.take(x, ind, mode="clip")
912-
expected_arr = np.take(x_np, ind_np, mode="clip")
913-
914-
assert (dpt.asnumpy(res) == expected_arr).all()
915-
916-
# clipping to -n<=i<n,
917-
# where n is the axis length
918-
ind = dpt.asarray([-4, -3, 0, 2, 4], dtype=np.intp, sycl_queue=q)
910+
ind = dpt.asarray([-6, -3, 0, 2, 6], dtype=np.intp, sycl_queue=q)
919911

920-
res = dpt.take(x, ind, mode="default")
921-
expected_arr = np.take(dpt.asnumpy(x), dpt.asnumpy(ind), mode="raise")
912+
res = dpt.take(x, ind, mode="clip")
913+
expected_arr = np.take(x_np, dpt.asnumpy(ind), mode="clip")
922914

923915
assert (dpt.asnumpy(res) == expected_arr).all()
924916

0 commit comments

Comments
 (0)