Skip to content

Commit

Permalink
add tests for negative use cases to improve coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
vtavana committed Nov 9, 2023
1 parent 14f043f commit e026c37
Showing 1 changed file with 23 additions and 31 deletions.
54 changes: 23 additions & 31 deletions tests/test_search.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dpctl.tensor as dpt
import numpy
import pytest
from numpy.testing import assert_allclose
Expand All @@ -7,63 +8,54 @@
from .helper import get_all_dtypes


@pytest.mark.parametrize("func", ["argmax", "argmin"])
@pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2])
@pytest.mark.parametrize("keepdims", [False, True])
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
def test_argmax_argmin(axis, keepdims, dtype):
def test_argmax_argmin(func, axis, keepdims, dtype):
a = numpy.arange(768, dtype=dtype).reshape((4, 4, 6, 8))
ia = dpnp.array(a)

np_res = numpy.argmax(a, axis=axis, keepdims=keepdims)
dpnp_res = dpnp.argmax(ia, axis=axis, keepdims=keepdims)

assert dpnp_res.shape == np_res.shape
assert_allclose(dpnp_res, np_res)

np_res = numpy.argmin(a, axis=axis, keepdims=keepdims)
dpnp_res = dpnp.argmin(ia, axis=axis, keepdims=keepdims)
np_res = getattr(numpy, func)(a, axis=axis, keepdims=keepdims)
dpnp_res = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims)

assert dpnp_res.shape == np_res.shape
assert_allclose(dpnp_res, np_res)


@pytest.mark.parametrize("func", ["argmax", "argmin"])
@pytest.mark.parametrize("axis", [None, 0, 1, -1])
@pytest.mark.parametrize("keepdims", [False, True])
def test_argmax_argmin_bool(axis, keepdims):
def test_argmax_argmin_bool(func, axis, keepdims):
a = numpy.arange(2, dtype=dpnp.bool)
a = numpy.tile(a, (2, 2))
ia = dpnp.array(a)

np_res = numpy.argmax(a, axis=axis, keepdims=keepdims)
dpnp_res = dpnp.argmax(ia, axis=axis, keepdims=keepdims)
np_res = getattr(numpy, func)(a, axis=axis, keepdims=keepdims)
dpnp_res = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims)

assert dpnp_res.shape == np_res.shape
assert_allclose(dpnp_res, np_res)

np_res = numpy.argmin(a, axis=axis, keepdims=keepdims)
dpnp_res = dpnp.argmin(ia, axis=axis, keepdims=keepdims)

assert dpnp_res.shape == np_res.shape
assert_allclose(dpnp_res, np_res)


@pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2])
@pytest.mark.parametrize("keepdims", [False, True])
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
def test_argmax_argmin_out(axis, keepdims, dtype):
a = numpy.arange(768, dtype=dtype).reshape((4, 4, 6, 8))
@pytest.mark.parametrize("func", ["argmax", "argmin"])
def test_argmax_argmin_out(func):
a = numpy.arange(6).reshape((2, 3))
ia = dpnp.array(a)

np_res = numpy.argmax(a, axis=axis, keepdims=keepdims)
np_res = getattr(numpy, func)(a, axis=0)
dpnp_res = dpnp.array(numpy.empty_like(np_res))
dpnp.argmax(ia, axis=axis, keepdims=keepdims, out=dpnp_res)
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)
assert_allclose(dpnp_res, np_res)

assert dpnp_res.shape == np_res.shape
dpnp_res = dpt.asarray(numpy.empty_like(np_res))
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)
assert_allclose(dpnp_res, np_res)

np_res = numpy.argmin(a, axis=axis, keepdims=keepdims)
dpnp_res = dpnp.array(numpy.empty_like(np_res))
dpnp.argmin(ia, axis=axis, keepdims=keepdims, out=dpnp_res)
dpnp_res = numpy.empty_like(np_res)
with pytest.raises(TypeError):
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)

assert dpnp_res.shape == np_res.shape
assert_allclose(dpnp_res, np_res)
dpnp_res = dpnp.array(numpy.empty((2, 3)))
with pytest.raises(ValueError):
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)

0 comments on commit e026c37

Please sign in to comment.