Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support of dpnp.extract() #1340

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"diag_indices",
"diag_indices_from",
"diagonal",
"extract",
"fill_diagonal",
"indices",
"nonzero",
Expand Down Expand Up @@ -232,6 +233,40 @@ def diagonal(x1, offset=0, axis1=0, axis2=1):
return call_origin(numpy.diagonal, x1, offset, axis1, axis2)


def extract(condition, x):
"""
Return the elements of an array that satisfy some condition.
For full documentation refer to :obj:`numpy.extract`.

antonwolfy marked this conversation as resolved.
Show resolved Hide resolved
Returns
-------
y : dpnp.ndarray
Rank 1 array of values from `x` where `condition` is True.

Limitations
-----------
Parameters `condition` and `x` are supported either as
:class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
Parameter `x` must be the same shape as `condition`.
Otherwise the function will be executed sequentially on CPU.
"""

check_input_type = lambda x: isinstance(x, (dpnp_array, dpt.usm_ndarray))
if check_input_type(condition) and check_input_type(x):
if condition.shape != x.shape:
pass
else:
dpt_condition = (
condition.get_array()
if isinstance(condition, dpnp_array)
else condition
)
dpt_array = x.get_array() if isinstance(x, dpnp_array) else x
return dpnp_array._create_from_usm_ndarray(dpt.extract(dpt_condition, dpt_array))

return call_origin(numpy.extract, condition, x)


def fill_diagonal(x1, val, wrap=False):
"""
Fill the main diagonal of the given array of any dimensionality.
Expand Down Expand Up @@ -296,7 +331,7 @@ def nonzero(x, /):
-------
y : tuple[dpnp.ndarray]
Indices of elements that are non-zero.

Limitations
-----------
Parameters `x` is supported as either :class:`dpnp.ndarray`
Expand Down
3 changes: 3 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from tests.third_party.cupy import testing as cupy_testing
from .helper import has_support_aspect64
import dpnp
import numpy

Expand All @@ -17,6 +18,8 @@


def _shaped_arange(shape, xp=dpnp, dtype=dpnp.float64, order='C'):
if dtype is dpnp.float64:
dtype = dpnp.float32 if not has_support_aspect64() else dtype
res = xp.array(orig_shaped_arange(shape, xp=numpy, dtype=dtype, order=order), dtype=dtype)
return res

Expand Down
9 changes: 9 additions & 0 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,12 @@ def is_win_platform():
Return True if a test is runing on Windows OS, False otherwise.
"""
return platform.startswith('win')


def has_support_aspect64(device=None):
"""
Return True if the device supports 64-bit precision floating point operations,
False otherwise.
"""
dev = dpctl.select_default_device() if device is None else device
return dev.has_aspect_fp64
6 changes: 0 additions & 6 deletions tests/skipped_tests.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -456,12 +456,6 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compr
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim_no_axis
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_axis
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_bool
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_empty_1dim
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_no_bool
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_shape_mismatch
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_size_mismatch
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_size_mismatch2
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_index_range_overflow
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
Expand Down
6 changes: 0 additions & 6 deletions tests/skipped_tests_gpu.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -650,12 +650,6 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compr
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim_no_axis
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_axis
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_bool
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_empty_1dim
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_no_bool
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_shape_mismatch
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_size_mismatch
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_extract_size_mismatch2
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_index_range_overflow
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
Expand Down
14 changes: 14 additions & 0 deletions tests/test_indexing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pytest
from .helper import get_all_dtypes


import dpnp

Expand Down Expand Up @@ -53,6 +55,18 @@ def test_diagonal(array, offset):
assert_array_equal(expected, result)


@pytest.mark.parametrize("arr_dtype", get_all_dtypes())
@pytest.mark.parametrize("cond_dtype", get_all_dtypes())
def test_extract_1d(arr_dtype, cond_dtype):
a = numpy.array([-2, -1, 0, 1, 2, 3], dtype=arr_dtype)
ia = dpnp.array(a)
cond = numpy.array([1, -1, 2, 0, -2, 3], dtype=cond_dtype)
icond = dpnp.array(cond)
expected = numpy.extract(cond, a)
result = dpnp.extract(icond, ia)
assert_array_equal(expected, result)


@pytest.mark.parametrize("val",
[-1, 0, 1],
ids=['-1', '0', '1'])
Expand Down
4 changes: 4 additions & 0 deletions tests/third_party/cupy/indexing_tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def test_extract_no_bool(self, xp, dtype):
b = xp.array([[1, 0, 1], [0, 1, 0], [1, 0, 1]], dtype=dtype)
return xp.extract(b, a)

@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@testing.numpy_cupy_array_equal()
def test_extract_shape_mismatch(self, xp):
a = testing.shaped_arange((2, 3), xp)
Expand All @@ -174,20 +175,23 @@ def test_extract_shape_mismatch(self, xp):
[True, False]])
return xp.extract(b, a)

@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@testing.numpy_cupy_array_equal()
def test_extract_size_mismatch(self, xp):
a = testing.shaped_arange((3, 3), xp)
b = xp.array([[True, False, True],
[False, True, False]])
return xp.extract(b, a)

@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@testing.numpy_cupy_array_equal()
def test_extract_size_mismatch2(self, xp):
a = testing.shaped_arange((3, 3), xp)
b = xp.array([[True, False, True, False],
[False, True, False, True]])
return xp.extract(b, a)

@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@testing.numpy_cupy_array_equal()
def test_extract_empty_1dim(self, xp):
a = testing.shaped_arange((3, 3), xp)
Expand Down
11 changes: 9 additions & 2 deletions tests/third_party/cupy/testing/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# from dpnp.core import internal
from tests.third_party.cupy.testing import array
from tests.third_party.cupy.testing import parameterized
from dpctl import select_default_device
# import dpnp
# import dpnp.scipy.sparse

Expand Down Expand Up @@ -654,9 +655,15 @@ def test_func(self, *args, **kw):
return test_func
return decorator

def _get_supported_float_dtypes():
if select_default_device().has_aspect_fp64:
return (numpy.float64, numpy.float32)
else:
return (numpy.float32,)


_complex_dtypes = ()
_regular_float_dtypes = (numpy.float64, numpy.float32)
_regular_float_dtypes = _get_supported_float_dtypes()
_float_dtypes = _regular_float_dtypes
_signed_dtypes = ()
_unsigned_dtypes = tuple(numpy.dtype(i).type for i in 'BHILQ')
Expand All @@ -667,7 +674,7 @@ def test_func(self, *args, **kw):


def _make_all_dtypes(no_float16, no_bool, no_complex):
return (numpy.float64, numpy.float32, numpy.int64, numpy.int32)
return (numpy.int64, numpy.int32) + _get_supported_float_dtypes()
# if no_float16:
# dtypes = _regular_float_dtypes
# else:
Expand Down