Skip to content

Commit d1d67b9

Browse files
committed
Boolean reduction axis=None logic change
This case now circumvents the call to permute_dims completely Tests were updated to reflect this change and cover both branches Also added a test for the axis=() case
1 parent 400e437 commit d1d67b9

File tree

2 files changed

+35
-14
lines changed

2 files changed

+35
-14
lines changed

dpctl/tensor/_utility_functions.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,29 @@ def _boolean_reduction(x, axis, keepdims, func):
1111

1212
nd = x.ndim
1313
if axis is None:
14-
axis = tuple(range(nd))
15-
if not isinstance(axis, (tuple, list)):
16-
axis = (axis,)
17-
axis = normalize_axis_tuple(axis, nd, "axis")
14+
red_nd = nd
15+
# case of a scalar
16+
if red_nd == 0:
17+
return dpt.astype(x, dpt.bool)
18+
x_tmp = x
19+
res_shape = tuple()
20+
perm = list(range(nd))
21+
else:
22+
if not isinstance(axis, (tuple, list)):
23+
axis = (axis,)
24+
axis = normalize_axis_tuple(axis, nd, "axis")
25+
26+
red_nd = len(axis)
27+
# check for axis=()
28+
if red_nd == 0:
29+
return dpt.astype(x, dpt.bool)
30+
perm = [i for i in range(nd) if i not in axis] + list(axis)
31+
x_tmp = dpt.permute_dims(x, perm)
32+
res_shape = x_tmp.shape[: nd - red_nd]
1833

1934
exec_q = x.sycl_queue
2035
res_usm_type = x.usm_type
2136

22-
red_nd = len(axis)
23-
if red_nd == 0:
24-
return dpt.astype(x, dpt.bool)
25-
26-
perm = [i for i in range(nd) if i not in axis] + list(axis)
27-
x_tmp = dpt.permute_dims(x, perm)
28-
res_shape = x_tmp.shape[: nd - red_nd]
29-
3037
wait_list = []
3138
res_tmp = dpt.empty(
3239
res_shape,
@@ -59,7 +66,6 @@ def _boolean_reduction(x, axis, keepdims, func):
5966
inv_perm = sorted(range(nd), key=lambda d: perm[d])
6067
res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm)
6168
dpctl.SyclEvent.wait_for(wait_list)
62-
6369
return res
6470

6571

dpctl/tests/test_usm_ndarray_utility_functions.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ def test_boolean_reduction_keepdims(func):
9292
assert res.shape == (2, 1, 1, 5, 1)
9393
assert_array_equal(dpt.asnumpy(res), np.full(res.shape, True))
9494

95+
res = func(x, axis=None, keepdims=True)
96+
assert res.shape == (1,) * x.ndim
97+
9598

9699
@pytest.mark.parametrize("func,identity", [(dpt.all, True), (dpt.any, False)])
97100
def test_boolean_reduction_empty(func, identity):
@@ -119,7 +122,19 @@ def test_boolean_reduction_scalars(func):
119122
get_queue_or_skip()
120123

121124
x = dpt.ones((), dtype="i4")
122-
func(x)
125+
assert_equal(dpt.asnumpy(func(x)), True)
126+
127+
x = dpt.zeros((), dtype="i4")
128+
assert_equal(dpt.asnumpy(func(x)), False)
129+
130+
131+
@pytest.mark.parametrize("func", [dpt.all, dpt.any])
132+
def test_boolean_reduction_empty_axis(func):
133+
get_queue_or_skip()
134+
135+
x = dpt.ones((5,), dtype="i4")
136+
res = func(x, axis=())
137+
assert_array_equal(dpt.asnumpy(res), dpt.asnumpy(x).astype(np.bool_))
123138

124139

125140
@pytest.mark.parametrize("func", [dpt.all, dpt.any])

0 commit comments

Comments
 (0)