4444from dpnp .dpnp_utils import *
4545
4646import dpnp
47+ from dpnp .dpnp_array import dpnp_array
48+
4749import numpy
50+ import dpctl .tensor as dpt
4851
4952
5053__all__ = [
@@ -181,7 +184,7 @@ def where(condition, x=None, y=None, /):
181184 Return elements chosen from `x` or `y` depending on `condition`.
182185
183186 When only `condition` is provided, this function is a shorthand for
184- :obj:`dpnp.nonzero(condition)`.
187+ :obj:`dpnp.nonzero(condition)`.
185188
186189 For full documentation refer to :obj:`numpy.where`.
187190
@@ -193,12 +196,13 @@ def where(condition, x=None, y=None, /):
193196
194197 Limitations
195198 -----------
196- Parameters `condition`, `x` and `y` are supported as either scalar, :class:`dpnp.ndarray`
199+ Parameter `condition` is supported as either :class:`dpnp.ndarray`
197200 or :class:`dpctl.tensor.usm_ndarray`.
201+ Parameters `x` and `y` are supported as either scalar, :class:`dpnp.ndarray`
202+ or :class:`dpctl.tensor.usm_ndarray`
198203 Otherwise the function will be executed sequentially on CPU.
199- Data type of `condition` parameter is limited by :obj:`dpnp.bool`.
200204 Input array data types of `x` and `y` are limited by supported DPNP :ref:`Data types`.
201-
205+
202206 See Also
203207 --------
204208 :obj:`nonzero` : The function that is called when `x` and `y`are omitted.
@@ -220,18 +224,17 @@ def where(condition, x=None, y=None, /):
220224 elif missing == 2 :
221225 return dpnp .nonzero (condition )
222226 elif missing == 0 :
223- # get USM type and queue to copy scalar from the host memory into a USM allocation
224- usm_type , queue = get_usm_allocations ([condition , x , y ])
225-
226- c_desc = dpnp .get_dpnp_descriptor (condition , copy_when_strides = False , copy_when_nondefault_queue = False ,
227- alloc_usm_type = usm_type , alloc_queue = queue )
228- x_desc = dpnp .get_dpnp_descriptor (x , copy_when_strides = False , copy_when_nondefault_queue = False ,
229- alloc_usm_type = usm_type , alloc_queue = queue )
230- y_desc = dpnp .get_dpnp_descriptor (y , copy_when_strides = False , copy_when_nondefault_queue = False ,
231- alloc_usm_type = usm_type , alloc_queue = queue )
232- if c_desc and x_desc and y_desc :
233- if c_desc .dtype != dpnp .bool :
234- raise TypeError ("condition must be a boolean array" )
235- return dpnp_where (c_desc , x_desc , y_desc ).get_pyobj ()
227+ check_input_type = lambda x : isinstance (x , (dpnp_array , dpt .usm_ndarray ))
228+ if check_input_type (condition ):
229+ if numpy .isscalar (x ) or numpy .isscalar (y ):
230+ # get USM type and queue to copy scalar from the host memory into a USM allocation
231+ usm_type , queue = get_usm_allocations ([condition , x , y ])
232+ x = dpt .asarray (x , usm_type = usm_type , sycl_queue = queue ) if numpy .isscalar (x ) else x
233+ y = dpt .asarray (y , usm_type = usm_type , sycl_queue = queue ) if numpy .isscalar (y ) else y
234+ if check_input_type (x ) and check_input_type (y ):
235+ dpt_condition = condition .get_array () if isinstance (condition , dpnp_array ) else condition
236+ dpt_x = x .get_array () if isinstance (x , dpnp_array ) else x
237+ dpt_y = y .get_array () if isinstance (y , dpnp_array ) else y
238+ return dpnp_array ._create_from_usm_ndarray (dpt .where (dpt_condition , dpt_x , dpt_y ))
236239
237240 return call_origin (numpy .where , condition , x , y )
0 commit comments