Skip to content

Commit eccd026

Browse files
Closes gh-1135
Handle x[True] and x[False] as NumPy does, even though the behavior may be undocumented. NumPy treats True as None (insert axis with size 1), and treats False as None followed by empty slicing (insert axis with size 0). Changed the logic of _basic_slice_meta utility function to correctly handle boolean scalars (surprisingly, `insinstance(True, int)` evaluates to `True`). 0d arrays are handled by Python scalars. Introduced _is_integral and _is_boolean utilty functions and used them in `_basic_slice_meta` utility.
1 parent 18bb612 commit eccd026

File tree

1 file changed

+59
-5
lines changed

1 file changed

+59
-5
lines changed

dpctl/tensor/_slicing.pxi

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Data Parallel Control (dpctl)
22
#
3-
# Copyright 2020-2022 Intel Corporation
3+
# Copyright 2020-2023 Intel Corporation
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -15,6 +15,11 @@
1515
# limitations under the License.
1616

1717
import numbers
18+
from cpython.buffer cimport PyObject_CheckBuffer
19+
20+
21+
cdef bint _is_buffer(object o):
22+
return PyObject_CheckBuffer(o)
1823

1924

2025
cdef Py_ssize_t _slice_len(
@@ -36,14 +41,23 @@ cdef Py_ssize_t _slice_len(
3641

3742
cdef bint _is_integral(object x) except *:
3843
"""Gives True if x is an integral slice spec"""
39-
if isinstance(x, (int, numbers.Integral)):
40-
return True
4144
if isinstance(x, usm_ndarray):
4245
if x.ndim > 0:
4346
return False
4447
if x.dtype.kind not in "ui":
4548
return False
4649
return True
50+
if isinstance(x, bool):
51+
return False
52+
if isinstance(x, int):
53+
return True
54+
if _is_buffer(x):
55+
mbuf = memoryview(x)
56+
if mbuf.ndim == 0:
57+
f = mbuf.format
58+
return f in "bBhHiIlLqQ"
59+
else:
60+
return False
4761
if callable(getattr(x, "__index__", None)):
4862
try:
4963
x.__index__()
@@ -53,6 +67,34 @@ cdef bint _is_integral(object x) except *:
5367
return False
5468

5569

70+
cdef bint _is_boolean(object x) except *:
71+
"""Gives True if x is an integral slice spec"""
72+
if isinstance(x, usm_ndarray):
73+
if x.ndim > 0:
74+
return False
75+
if x.dtype.kind not in "b":
76+
return False
77+
return True
78+
if isinstance(x, bool):
79+
return True
80+
if isinstance(x, int):
81+
return False
82+
if _is_buffer(x):
83+
mbuf = memoryview(x)
84+
if mbuf.ndim == 0:
85+
f = mbuf.format
86+
return f in "?"
87+
else:
88+
return False
89+
if callable(getattr(x, "__bool__", None)):
90+
try:
91+
x.__bool__()
92+
except (TypeError, ValueError):
93+
return False
94+
return True
95+
return False
96+
97+
5698
def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
5799
"""
58100
Give basic slicing index `ind` and array layout information produce
@@ -82,6 +124,11 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
82124
_no_advanced_ind,
83125
_no_advanced_pos
84126
)
127+
elif _is_boolean(ind):
128+
if ind:
129+
return ((1,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos)
130+
else:
131+
return ((0,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos)
85132
elif _is_integral(ind):
86133
ind = ind.__index__()
87134
if 0 <= ind < shape[0]:
@@ -117,6 +164,10 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
117164
axes_referenced += 1
118165
if array_streak_started:
119166
array_streak_interrupted = True
167+
elif _is_boolean(i):
168+
newaxis_count += 1
169+
if array_streak_started:
170+
array_streak_interrupted = True
120171
elif _is_integral(i):
121172
explicit_index += 1
122173
axes_referenced += 1
@@ -133,9 +184,9 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
133184
"separated by basic slicing specs."
134185
)
135186
dt_k = i.dtype.kind
136-
if dt_k == "b":
187+
if dt_k == "b" and i.ndim > 0:
137188
axes_referenced += i.ndim
138-
elif dt_k in "ui":
189+
elif dt_k in "ui" and i.ndim > 0:
139190
axes_referenced += 1
140191
else:
141192
raise IndexError(
@@ -186,6 +237,9 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
186237
if sh_i == 0:
187238
is_empty = True
188239
k = k_new
240+
elif _is_boolean(ind_i):
241+
new_shape.append(1 if ind_i else 0)
242+
new_strides.append(0)
189243
elif _is_integral(ind_i):
190244
ind_i = ind_i.__index__()
191245
if 0 <= ind_i < shape[k]:

0 commit comments

Comments
 (0)