From 1f15ccc36a35e9b3dc1b2a90ad928a151d113203 Mon Sep 17 00:00:00 2001 From: Alexander Kalistratov Date: Mon, 14 Aug 2023 09:50:45 +0200 Subject: [PATCH] Add more tests for dpnp.sum and sum_over_axis_0 extension (#1488) * Add more tests for dpnp.sum and sum_over_axis_0 extension * Add keepdims=True, bool and complex dtypes --------- Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com> --- tests/test_extensions.py | 34 +++++++++++++++++-------- tests/test_mathematical.py | 51 +++++++++++++++++++++++++++----------- 2 files changed, 60 insertions(+), 25 deletions(-) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index a020be77637..c0e1ab3ea77 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -193,8 +193,6 @@ def test_mean_over_axis_0_unsupported_out_types( input = dpt.empty((height, width), dtype=input_type, device=device) output = dpt.empty(width, dtype=output_type, device=device) - if func(input, output): - print(output_type) assert func(input, output) is None @@ -202,7 +200,9 @@ def test_mean_over_axis_0_unsupported_out_types( "func, device, input_type, output_type", product(mean_sum, all_devices, [dpt.float32], [dpt.float32]), ) -def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type): +def test_mean_sum_over_axis_0_f_contig_input( + func, device, input_type, output_type +): skip_unsupported(device, input_type) skip_unsupported(device, output_type) @@ -212,8 +212,6 @@ def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type): input = dpt.empty((height, width), dtype=input_type, device=device).T output = dpt.empty(width, dtype=output_type, device=device) - if func(input, output): - print(output_type) assert func(input, output) is None @@ -221,7 +219,7 @@ def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type): "func, device, input_type, output_type", product(mean_sum, all_devices, [dpt.float32], [dpt.float32]), ) -def test_mean_over_axis_0_f_contig_output( +def test_mean_sum_over_axis_0_f_contig_output( func, device, input_type, output_type ): skip_unsupported(device, input_type) @@ -230,9 +228,25 @@ def test_mean_over_axis_0_f_contig_output( height = 1 width = 10 - input = dpt.empty((height, 10), dtype=input_type, device=device) - output = dpt.empty(20, dtype=output_type, device=device)[::2] + input = dpt.empty((height, width), dtype=input_type, device=device) + output = dpt.empty(width * 2, dtype=output_type, device=device)[::2] + + assert func(input, output) is None + + +@pytest.mark.parametrize( + "func, device, input_type, output_type", + product(mean_sum, all_devices, [dpt.float32], [dpt.float32, dpt.float64]), +) +def test_mean_sum_over_axis_0_big_output(func, device, input_type, output_type): + skip_unsupported(device, input_type) + skip_unsupported(device, output_type) + + local_mem_size = device.local_mem_size + height = 1 + width = 1 + local_mem_size // output_type.itemsize + + input = dpt.empty((height, width), dtype=input_type, device=device) + output = dpt.empty(width, dtype=output_type, device=device) - if func(input, output): - print(output_type) assert func(input, output) is None diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index fed1928e076..e0a16869567 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -1,3 +1,5 @@ +from itertools import permutations + import numpy import pytest from numpy.testing import ( @@ -1056,23 +1058,42 @@ def test_sum_empty_out(dtype): @pytest.mark.parametrize( - "shape", [(), (1, 2, 3), (1, 0, 2), (10), (3, 3, 3), (5, 5), (0, 6)] -) -@pytest.mark.parametrize( - "dtype_in", get_all_dtypes(no_complex=True, no_bool=True) -) -@pytest.mark.parametrize( - "dtype_out", get_all_dtypes(no_complex=True, no_bool=True) + "shape", + [ + (), + (1, 2, 3), + (1, 0, 2), + (10,), + (3, 3, 3), + (5, 5), + (0, 6), + (10, 1), + (1, 10), + ], ) -def test_sum(shape, dtype_in, dtype_out): - a_np = numpy.ones(shape, dtype=dtype_in) - a = dpnp.ones(shape, dtype=dtype_in) - axes = [None, 0, 1, 2] +@pytest.mark.parametrize("dtype_in", get_all_dtypes()) +@pytest.mark.parametrize("dtype_out", get_all_dtypes()) +@pytest.mark.parametrize("transpose", [True, False]) +@pytest.mark.parametrize("keepdims", [True, False]) +def test_sum(shape, dtype_in, dtype_out, transpose, keepdims): + size = numpy.prod(shape) + a_np = numpy.arange(size).astype(dtype_in).reshape(shape) + a = dpnp.asarray(a_np) + + if transpose: + a_np = a_np.T + a = a.T + + axes_range = list(numpy.arange(len(shape))) + axes = [None] + axes += axes_range + axes += permutations(axes_range, 2) + axes.append(tuple(axes_range)) + for axis in axes: - if axis is None or axis < a.ndim: - numpy_res = a_np.sum(axis=axis, dtype=dtype_out) - dpnp_res = a.sum(axis=axis, dtype=dtype_out) - assert_array_equal(numpy_res, dpnp_res.asnumpy()) + numpy_res = a_np.sum(axis=axis, dtype=dtype_out, keepdims=keepdims) + dpnp_res = a.sum(axis=axis, dtype=dtype_out, keepdims=keepdims) + assert_array_equal(numpy_res, dpnp_res.asnumpy()) class TestMean: