Skip to content

Commit 0b6be45

Browse files
vlad-perevezentsevantonwolfy
authored andcommitted
Reuse dpctl.tensor.where in dpnp.where
1 parent 60f00ec commit 0b6be45

File tree

3 files changed

+39
-20
lines changed

3 files changed

+39
-20
lines changed

dpnp/dpnp_iface_searching.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@
4444
from dpnp.dpnp_utils import *
4545

4646
import dpnp
47+
from dpnp.dpnp_array import dpnp_array
48+
4749
import numpy
50+
import dpctl.tensor as dpt
4851

4952

5053
__all__ = [
@@ -181,7 +184,7 @@ def where(condition, x=None, y=None, /):
181184
Return elements chosen from `x` or `y` depending on `condition`.
182185
183186
When only `condition` is provided, this function is a shorthand for
184-
:obj:`dpnp.nonzero(condition)`.
187+
:obj:`dpnp.nonzero(condition)`.
185188
186189
For full documentation refer to :obj:`numpy.where`.
187190
@@ -193,12 +196,13 @@ def where(condition, x=None, y=None, /):
193196
194197
Limitations
195198
-----------
196-
Parameters `condition`, `x` and `y` are supported as either scalar, :class:`dpnp.ndarray`
199+
Parameter `condition` is supported as either :class:`dpnp.ndarray`
197200
or :class:`dpctl.tensor.usm_ndarray`.
201+
Parameters `x` and `y` are supported as either scalar, :class:`dpnp.ndarray`
202+
or :class:`dpctl.tensor.usm_ndarray`
198203
Otherwise the function will be executed sequentially on CPU.
199-
Data type of `condition` parameter is limited by :obj:`dpnp.bool`.
200204
Input array data types of `x` and `y` are limited by supported DPNP :ref:`Data types`.
201-
205+
202206
See Also
203207
--------
204208
:obj:`nonzero` : The function that is called when `x` and `y`are omitted.
@@ -220,18 +224,17 @@ def where(condition, x=None, y=None, /):
220224
elif missing == 2:
221225
return dpnp.nonzero(condition)
222226
elif missing == 0:
223-
# get USM type and queue to copy scalar from the host memory into a USM allocation
224-
usm_type, queue = get_usm_allocations([condition, x, y])
225-
226-
c_desc = dpnp.get_dpnp_descriptor(condition, copy_when_strides=False, copy_when_nondefault_queue=False,
227-
alloc_usm_type=usm_type, alloc_queue=queue)
228-
x_desc = dpnp.get_dpnp_descriptor(x, copy_when_strides=False, copy_when_nondefault_queue=False,
229-
alloc_usm_type=usm_type, alloc_queue=queue)
230-
y_desc = dpnp.get_dpnp_descriptor(y, copy_when_strides=False, copy_when_nondefault_queue=False,
231-
alloc_usm_type=usm_type, alloc_queue=queue)
232-
if c_desc and x_desc and y_desc:
233-
if c_desc.dtype != dpnp.bool:
234-
raise TypeError("condition must be a boolean array")
235-
return dpnp_where(c_desc, x_desc, y_desc).get_pyobj()
227+
check_input_type = lambda x: isinstance(x, (dpnp_array, dpt.usm_ndarray))
228+
if check_input_type(condition):
229+
if numpy.isscalar(x) or numpy.isscalar(y):
230+
# get USM type and queue to copy scalar from the host memory into a USM allocation
231+
usm_type, queue = get_usm_allocations([condition, x, y])
232+
x = dpt.asarray(x, usm_type=usm_type, sycl_queue=queue) if numpy.isscalar(x) else x
233+
y = dpt.asarray(y, usm_type=usm_type, sycl_queue=queue) if numpy.isscalar(y) else y
234+
if check_input_type(x) and check_input_type(y):
235+
dpt_condition = condition.get_array() if isinstance(condition, dpnp_array) else condition
236+
dpt_x = x.get_array() if isinstance(x, dpnp_array) else x
237+
dpt_y = y.get_array() if isinstance(y, dpnp_array) else y
238+
return dpnp_array._create_from_usm_ndarray(dpt.where(dpt_condition, dpt_x, dpt_y))
236239

237240
return call_origin(numpy.where, condition, x, y)

tests/test_indexing.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,3 +581,22 @@ def test_triu_indices_from(array, k):
581581
result = dpnp.triu_indices_from(ia, k)
582582
expected = numpy.triu_indices_from(a, k)
583583
assert_array_equal(expected, result)
584+
585+
586+
@pytest.mark.parametrize("cond_dtype", get_all_dtypes())
587+
@pytest.mark.parametrize("scalar_dtype", get_all_dtypes(no_none=True))
588+
def test_where_with_scalars(cond_dtype, scalar_dtype):
589+
a = numpy.array([-1, 0, 1, 0], dtype=cond_dtype)
590+
ia = dpnp.array(a)
591+
592+
result = dpnp.where(ia, scalar_dtype(1), scalar_dtype(0))
593+
expected = numpy.where(a, scalar_dtype(1), scalar_dtype(0))
594+
assert_array_equal(expected, result)
595+
596+
result = dpnp.where(ia, ia*2, scalar_dtype(0))
597+
expected = numpy.where(a, a*2, scalar_dtype(0))
598+
assert_array_equal(expected, result)
599+
600+
result = dpnp.where(ia, scalar_dtype(1), dpnp.array(0))
601+
expected = numpy.where(a, scalar_dtype(1), numpy.array(0))
602+
assert_array_equal(expected, result)

tests/third_party/cupy/sorting_tests/test_search.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,6 @@ def test_argminmax_dtype(self, in_dtype, result_dtype):
262262
{'cond_shape': (2, 3, 4), 'x_shape': (2, 3, 4), 'y_shape': (3, 4)},
263263
{'cond_shape': (3, 4), 'x_shape': (2, 3, 4), 'y_shape': (4,)},
264264
)
265-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
266265
@testing.gpu
267266
class TestWhereTwoArrays(unittest.TestCase):
268267

@@ -274,8 +273,6 @@ def test_where_two_arrays(self, xp, cond_type, x_type, y_type):
274273
# Almost all values of a matrix `shaped_random` makes are not zero.
275274
# To make a sparse matrix, we need multiply `m`.
276275
cond = testing.shaped_random(self.cond_shape, xp, cond_type) * m
277-
if xp is cupy:
278-
cond = cond.astype(cupy.bool)
279276
x = testing.shaped_random(self.x_shape, xp, x_type, seed=0)
280277
y = testing.shaped_random(self.y_shape, xp, y_type, seed=1)
281278
return xp.where(cond, x, y)

0 commit comments

Comments
 (0)