diff --git a/dpnp/dpnp_algo/dpnp_elementwise_common.py b/dpnp/dpnp_algo/dpnp_elementwise_common.py index 410f5542f7b..5f631a839e2 100644 --- a/dpnp/dpnp_algo/dpnp_elementwise_common.py +++ b/dpnp/dpnp_algo/dpnp_elementwise_common.py @@ -212,6 +212,14 @@ def dpnp_add(x1, x2, out=None, order="K"): """ +bitwise_and_func = BinaryElementwiseFunc( + "bitwise_and", + ti._bitwise_and_result_type, + ti._bitwise_and, + _bitwise_and_docstring_, +) + + def dpnp_bitwise_and(x1, x2, out=None, order="K"): """Invokes bitwise_and() from dpctl.tensor implementation for bitwise_and() function.""" @@ -220,13 +228,9 @@ def dpnp_bitwise_and(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "bitwise_and", - ti._bitwise_and_result_type, - ti._bitwise_and, - _bitwise_and_docstring_, + res_usm = bitwise_and_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -256,6 +260,14 @@ def dpnp_bitwise_and(x1, x2, out=None, order="K"): """ +bitwise_or_func = BinaryElementwiseFunc( + "bitwise_or", + ti._bitwise_or_result_type, + ti._bitwise_or, + _bitwise_or_docstring_, +) + + def dpnp_bitwise_or(x1, x2, out=None, order="K"): """Invokes bitwise_or() from dpctl.tensor implementation for bitwise_or() function.""" @@ -264,13 +276,9 @@ def dpnp_bitwise_or(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "bitwise_or", - ti._bitwise_or_result_type, - ti._bitwise_or, - _bitwise_or_docstring_, + res_usm = bitwise_or_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -300,6 +308,14 @@ def dpnp_bitwise_or(x1, x2, out=None, order="K"): """ +bitwise_xor_func = BinaryElementwiseFunc( + "bitwise_xor", + ti._bitwise_xor_result_type, + ti._bitwise_xor, + _bitwise_xor_docstring_, +) + + def dpnp_bitwise_xor(x1, x2, out=None, order="K"): """Invokes bitwise_xor() from dpctl.tensor implementation for bitwise_xor() function.""" @@ -308,13 +324,9 @@ def dpnp_bitwise_xor(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "bitwise_xor", - ti._bitwise_xor_result_type, - ti._bitwise_xor, - _bitwise_xor_docstring_, + res_usm = bitwise_xor_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -629,6 +641,14 @@ def dpnp_greater_equal(x1, x2, out=None, order="K"): """ +invert_func = UnaryElementwiseFunc( + "invert", + ti._bitwise_invert_result_type, + ti._bitwise_invert, + _invert_docstring, +) + + def dpnp_invert(x, out=None, order="K"): """Invokes bitwise_invert() from dpctl.tensor implementation for invert() function.""" @@ -636,13 +656,7 @@ def dpnp_invert(x, out=None, order="K"): x_usm = dpnp.get_usm_ndarray(x) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = UnaryElementwiseFunc( - "invert", - ti._bitwise_invert_result_type, - ti._bitwise_invert, - _invert_docstring, - ) - res_usm = func(x_usm, out=out_usm, order=order) + res_usm = invert_func(x_usm, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -778,6 +792,14 @@ def dpnp_isnan(x, out=None, order="K"): """ +left_shift_func = BinaryElementwiseFunc( + "bitwise_leftt_shift", + ti._bitwise_left_shift_result_type, + ti._bitwise_left_shift, + _left_shift_docstring_, +) + + def dpnp_left_shift(x1, x2, out=None, order="K"): """Invokes bitwise_left_shift() from dpctl.tensor implementation for left_shift() function.""" @@ -786,13 +808,9 @@ def dpnp_left_shift(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "bitwise_leftt_shift", - ti._bitwise_left_shift_result_type, - ti._bitwise_left_shift, - _left_shift_docstring_, + res_usm = left_shift_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) @@ -1199,6 +1217,14 @@ def dpnp_not_equal(x1, x2, out=None, order="K"): """ +right_shift_func = BinaryElementwiseFunc( + "bitwise_right_shift", + ti._bitwise_right_shift_result_type, + ti._bitwise_right_shift, + _right_shift_docstring_, +) + + def dpnp_right_shift(x1, x2, out=None, order="K"): """Invokes bitwise_right_shift() from dpctl.tensor implementation for right_shift() function.""" @@ -1207,13 +1233,9 @@ def dpnp_right_shift(x1, x2, out=None, order="K"): x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2) out_usm = None if out is None else dpnp.get_usm_ndarray(out) - func = BinaryElementwiseFunc( - "bitwise_right_shift", - ti._bitwise_right_shift_result_type, - ti._bitwise_right_shift, - _right_shift_docstring_, + res_usm = right_shift_func( + x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order ) - res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order) return dpnp_array._create_from_usm_ndarray(res_usm) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index a020be77637..c0e1ab3ea77 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -193,8 +193,6 @@ def test_mean_over_axis_0_unsupported_out_types( input = dpt.empty((height, width), dtype=input_type, device=device) output = dpt.empty(width, dtype=output_type, device=device) - if func(input, output): - print(output_type) assert func(input, output) is None @@ -202,7 +200,9 @@ def test_mean_over_axis_0_unsupported_out_types( "func, device, input_type, output_type", product(mean_sum, all_devices, [dpt.float32], [dpt.float32]), ) -def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type): +def test_mean_sum_over_axis_0_f_contig_input( + func, device, input_type, output_type +): skip_unsupported(device, input_type) skip_unsupported(device, output_type) @@ -212,8 +212,6 @@ def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type): input = dpt.empty((height, width), dtype=input_type, device=device).T output = dpt.empty(width, dtype=output_type, device=device) - if func(input, output): - print(output_type) assert func(input, output) is None @@ -221,7 +219,7 @@ def test_mean_over_axis_0_f_contig_input(func, device, input_type, output_type): "func, device, input_type, output_type", product(mean_sum, all_devices, [dpt.float32], [dpt.float32]), ) -def test_mean_over_axis_0_f_contig_output( +def test_mean_sum_over_axis_0_f_contig_output( func, device, input_type, output_type ): skip_unsupported(device, input_type) @@ -230,9 +228,25 @@ def test_mean_over_axis_0_f_contig_output( height = 1 width = 10 - input = dpt.empty((height, 10), dtype=input_type, device=device) - output = dpt.empty(20, dtype=output_type, device=device)[::2] + input = dpt.empty((height, width), dtype=input_type, device=device) + output = dpt.empty(width * 2, dtype=output_type, device=device)[::2] + + assert func(input, output) is None + + +@pytest.mark.parametrize( + "func, device, input_type, output_type", + product(mean_sum, all_devices, [dpt.float32], [dpt.float32, dpt.float64]), +) +def test_mean_sum_over_axis_0_big_output(func, device, input_type, output_type): + skip_unsupported(device, input_type) + skip_unsupported(device, output_type) + + local_mem_size = device.local_mem_size + height = 1 + width = 1 + local_mem_size // output_type.itemsize + + input = dpt.empty((height, width), dtype=input_type, device=device) + output = dpt.empty(width, dtype=output_type, device=device) - if func(input, output): - print(output_type) assert func(input, output) is None diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index fed1928e076..e0a16869567 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -1,3 +1,5 @@ +from itertools import permutations + import numpy import pytest from numpy.testing import ( @@ -1056,23 +1058,42 @@ def test_sum_empty_out(dtype): @pytest.mark.parametrize( - "shape", [(), (1, 2, 3), (1, 0, 2), (10), (3, 3, 3), (5, 5), (0, 6)] -) -@pytest.mark.parametrize( - "dtype_in", get_all_dtypes(no_complex=True, no_bool=True) -) -@pytest.mark.parametrize( - "dtype_out", get_all_dtypes(no_complex=True, no_bool=True) + "shape", + [ + (), + (1, 2, 3), + (1, 0, 2), + (10,), + (3, 3, 3), + (5, 5), + (0, 6), + (10, 1), + (1, 10), + ], ) -def test_sum(shape, dtype_in, dtype_out): - a_np = numpy.ones(shape, dtype=dtype_in) - a = dpnp.ones(shape, dtype=dtype_in) - axes = [None, 0, 1, 2] +@pytest.mark.parametrize("dtype_in", get_all_dtypes()) +@pytest.mark.parametrize("dtype_out", get_all_dtypes()) +@pytest.mark.parametrize("transpose", [True, False]) +@pytest.mark.parametrize("keepdims", [True, False]) +def test_sum(shape, dtype_in, dtype_out, transpose, keepdims): + size = numpy.prod(shape) + a_np = numpy.arange(size).astype(dtype_in).reshape(shape) + a = dpnp.asarray(a_np) + + if transpose: + a_np = a_np.T + a = a.T + + axes_range = list(numpy.arange(len(shape))) + axes = [None] + axes += axes_range + axes += permutations(axes_range, 2) + axes.append(tuple(axes_range)) + for axis in axes: - if axis is None or axis < a.ndim: - numpy_res = a_np.sum(axis=axis, dtype=dtype_out) - dpnp_res = a.sum(axis=axis, dtype=dtype_out) - assert_array_equal(numpy_res, dpnp_res.asnumpy()) + numpy_res = a_np.sum(axis=axis, dtype=dtype_out, keepdims=keepdims) + dpnp_res = a.sum(axis=axis, dtype=dtype_out, keepdims=keepdims) + assert_array_equal(numpy_res, dpnp_res.asnumpy()) class TestMean: