Skip to content

Slicing bug gh 1135 #1136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 59 additions & 5 deletions dpctl/tensor/_slicing.pxi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2022 Intel Corporation
# Copyright 2020-2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,11 @@
# limitations under the License.

import numbers
from cpython.buffer cimport PyObject_CheckBuffer


cdef bint _is_buffer(object o):
return PyObject_CheckBuffer(o)


cdef Py_ssize_t _slice_len(
Expand All @@ -36,14 +41,23 @@ cdef Py_ssize_t _slice_len(

cdef bint _is_integral(object x) except *:
"""Gives True if x is an integral slice spec"""
if isinstance(x, (int, numbers.Integral)):
return True
if isinstance(x, usm_ndarray):
if x.ndim > 0:
return False
if x.dtype.kind not in "ui":
return False
return True
if isinstance(x, bool):
return False
if isinstance(x, int):
return True
if _is_buffer(x):
mbuf = memoryview(x)
if mbuf.ndim == 0:
f = mbuf.format
return f in "bBhHiIlLqQ"
else:
return False
if callable(getattr(x, "__index__", None)):
try:
x.__index__()
Expand All @@ -53,6 +67,34 @@ cdef bint _is_integral(object x) except *:
return False


cdef bint _is_boolean(object x) except *:
"""Gives True if x is an integral slice spec"""
if isinstance(x, usm_ndarray):
if x.ndim > 0:
return False
if x.dtype.kind not in "b":
return False
return True
if isinstance(x, bool):
return True
if isinstance(x, int):
return False
if _is_buffer(x):
mbuf = memoryview(x)
if mbuf.ndim == 0:
f = mbuf.format
return f in "?"
else:
return False
if callable(getattr(x, "__bool__", None)):
try:
x.__bool__()
except (TypeError, ValueError):
return False
return True
return False


def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
"""
Give basic slicing index `ind` and array layout information produce
Expand Down Expand Up @@ -82,6 +124,11 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
_no_advanced_ind,
_no_advanced_pos
)
elif _is_boolean(ind):
if ind:
return ((1,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos)
else:
return ((0,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos)
elif _is_integral(ind):
ind = ind.__index__()
if 0 <= ind < shape[0]:
Expand Down Expand Up @@ -117,6 +164,10 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
axes_referenced += 1
if array_streak_started:
array_streak_interrupted = True
elif _is_boolean(i):
newaxis_count += 1
if array_streak_started:
array_streak_interrupted = True
elif _is_integral(i):
explicit_index += 1
axes_referenced += 1
Expand All @@ -133,9 +184,9 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
"separated by basic slicing specs."
)
dt_k = i.dtype.kind
if dt_k == "b":
if dt_k == "b" and i.ndim > 0:
axes_referenced += i.ndim
elif dt_k in "ui":
elif dt_k in "ui" and i.ndim > 0:
axes_referenced += 1
else:
raise IndexError(
Expand Down Expand Up @@ -186,6 +237,9 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
if sh_i == 0:
is_empty = True
k = k_new
elif _is_boolean(ind_i):
new_shape.append(1 if ind_i else 0)
new_strides.append(0)
elif _is_integral(ind_i):
ind_i = ind_i.__index__()
if 0 <= ind_i < shape[k]:
Expand Down
26 changes: 26 additions & 0 deletions dpctl/tests/test_usm_ndarray_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,32 @@ def test_integer_strided_indexing():
assert (dpt.asnumpy(y) == dpt.asnumpy(yc)).all()


def test_TrueFalse_indexing():
get_queue_or_skip()
n0, n1 = 2, 3
x = dpt.ones((n0, n1))
for ind in [True, dpt.asarray(True)]:
y1 = x[ind]
assert y1.shape == (1, n0, n1)
assert y1._pointer == x._pointer
y2 = x[:, ind]
assert y2.shape == (n0, 1, n1)
assert y2._pointer == x._pointer
y3 = x[..., ind]
assert y3.shape == (n0, n1, 1)
assert y3._pointer == x._pointer
for ind in [False, dpt.asarray(False)]:
y1 = x[ind]
assert y1.shape == (0, n0, n1)
assert y1._pointer == x._pointer
y2 = x[:, ind]
assert y2.shape == (n0, 0, n1)
assert y2._pointer == x._pointer
y3 = x[..., ind]
assert y3.shape == (n0, n1, 0)
assert y3._pointer == x._pointer


@pytest.mark.parametrize(
"data_dt",
_all_dtypes,
Expand Down