Skip to content

Commit 1f15ccc

Browse files
Add more tests for dpnp.sum and sum_over_axis_0 extension (#1488)
* Add more tests for dpnp.sum and sum_over_axis_0 extension * Add keepdims=True, bool and complex dtypes --------- Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
1 parent 2f235bf commit 1f15ccc

File tree

2 files changed

+60
-25
lines changed

2 files changed

+60
-25
lines changed

tests/test_extensions.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,16 +193,16 @@ def test_mean_over_axis_0_unsupported_out_types(
193193
input = dpt.empty((height, width), dtype=input_type, device=device)
194194
output = dpt.empty(width, dtype=output_type, device=device)
195195

196-
if func(input, output):
197-
print(output_type)
198196
assert func(input, output) is None
199197

200198

201199
@pytest.mark.parametrize(
202200
"func, device, input_type, output_type",
203201
product(mean_sum, all_devices, [dpt.float32], [dpt.float32]),
204202
)
205-
def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type):
203+
def test_mean_sum_over_axis_0_f_contig_input(
204+
func, device, input_type, output_type
205+
):
206206
skip_unsupported(device, input_type)
207207
skip_unsupported(device, output_type)
208208

@@ -212,16 +212,14 @@ def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type):
212212
input = dpt.empty((height, width), dtype=input_type, device=device).T
213213
output = dpt.empty(width, dtype=output_type, device=device)
214214

215-
if func(input, output):
216-
print(output_type)
217215
assert func(input, output) is None
218216

219217

220218
@pytest.mark.parametrize(
221219
"func, device, input_type, output_type",
222220
product(mean_sum, all_devices, [dpt.float32], [dpt.float32]),
223221
)
224-
def test_mean_over_axis_0_f_contig_output(
222+
def test_mean_sum_over_axis_0_f_contig_output(
225223
func, device, input_type, output_type
226224
):
227225
skip_unsupported(device, input_type)
@@ -230,9 +228,25 @@ def test_mean_over_axis_0_f_contig_output(
230228
height = 1
231229
width = 10
232230

233-
input = dpt.empty((height, 10), dtype=input_type, device=device)
234-
output = dpt.empty(20, dtype=output_type, device=device)[::2]
231+
input = dpt.empty((height, width), dtype=input_type, device=device)
232+
output = dpt.empty(width * 2, dtype=output_type, device=device)[::2]
233+
234+
assert func(input, output) is None
235+
236+
237+
@pytest.mark.parametrize(
238+
"func, device, input_type, output_type",
239+
product(mean_sum, all_devices, [dpt.float32], [dpt.float32, dpt.float64]),
240+
)
241+
def test_mean_sum_over_axis_0_big_output(func, device, input_type, output_type):
242+
skip_unsupported(device, input_type)
243+
skip_unsupported(device, output_type)
244+
245+
local_mem_size = device.local_mem_size
246+
height = 1
247+
width = 1 + local_mem_size // output_type.itemsize
248+
249+
input = dpt.empty((height, width), dtype=input_type, device=device)
250+
output = dpt.empty(width, dtype=output_type, device=device)
235251

236-
if func(input, output):
237-
print(output_type)
238252
assert func(input, output) is None

tests/test_mathematical.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from itertools import permutations
2+
13
import numpy
24
import pytest
35
from numpy.testing import (
@@ -1056,23 +1058,42 @@ def test_sum_empty_out(dtype):
10561058

10571059

10581060
@pytest.mark.parametrize(
1059-
"shape", [(), (1, 2, 3), (1, 0, 2), (10), (3, 3, 3), (5, 5), (0, 6)]
1060-
)
1061-
@pytest.mark.parametrize(
1062-
"dtype_in", get_all_dtypes(no_complex=True, no_bool=True)
1063-
)
1064-
@pytest.mark.parametrize(
1065-
"dtype_out", get_all_dtypes(no_complex=True, no_bool=True)
1061+
"shape",
1062+
[
1063+
(),
1064+
(1, 2, 3),
1065+
(1, 0, 2),
1066+
(10,),
1067+
(3, 3, 3),
1068+
(5, 5),
1069+
(0, 6),
1070+
(10, 1),
1071+
(1, 10),
1072+
],
10661073
)
1067-
def test_sum(shape, dtype_in, dtype_out):
1068-
a_np = numpy.ones(shape, dtype=dtype_in)
1069-
a = dpnp.ones(shape, dtype=dtype_in)
1070-
axes = [None, 0, 1, 2]
1074+
@pytest.mark.parametrize("dtype_in", get_all_dtypes())
1075+
@pytest.mark.parametrize("dtype_out", get_all_dtypes())
1076+
@pytest.mark.parametrize("transpose", [True, False])
1077+
@pytest.mark.parametrize("keepdims", [True, False])
1078+
def test_sum(shape, dtype_in, dtype_out, transpose, keepdims):
1079+
size = numpy.prod(shape)
1080+
a_np = numpy.arange(size).astype(dtype_in).reshape(shape)
1081+
a = dpnp.asarray(a_np)
1082+
1083+
if transpose:
1084+
a_np = a_np.T
1085+
a = a.T
1086+
1087+
axes_range = list(numpy.arange(len(shape)))
1088+
axes = [None]
1089+
axes += axes_range
1090+
axes += permutations(axes_range, 2)
1091+
axes.append(tuple(axes_range))
1092+
10711093
for axis in axes:
1072-
if axis is None or axis < a.ndim:
1073-
numpy_res = a_np.sum(axis=axis, dtype=dtype_out)
1074-
dpnp_res = a.sum(axis=axis, dtype=dtype_out)
1075-
assert_array_equal(numpy_res, dpnp_res.asnumpy())
1094+
numpy_res = a_np.sum(axis=axis, dtype=dtype_out, keepdims=keepdims)
1095+
dpnp_res = a.sum(axis=axis, dtype=dtype_out, keepdims=keepdims)
1096+
assert_array_equal(numpy_res, dpnp_res.asnumpy())
10761097

10771098

10781099
class TestMean:

0 commit comments

Comments
 (0)