@@ -551,12 +551,12 @@ def extract(condition, a):
551
551
"""
552
552
553
553
usm_a = dpnp .get_usm_ndarray (a )
554
- if not dpnp .is_supported_array_type ( condition ):
555
- usm_cond = dpt . asarray (
556
- condition , usm_type = a . usm_type , sycl_queue = a . sycl_queue
557
- )
558
- else :
559
- usm_cond = dpnp . get_usm_ndarray ( condition )
554
+ usm_cond = dpnp .as_usm_ndarray (
555
+ condition ,
556
+ dtype = dpnp . bool ,
557
+ usm_type = usm_a . usm_type ,
558
+ sycl_queue = usm_a . sycl_queue ,
559
+ )
560
560
561
561
if usm_cond .size != usm_a .size :
562
562
usm_a = dpt .reshape (usm_a , - 1 )
@@ -1011,30 +1011,74 @@ def nonzero(a):
1011
1011
)
1012
1012
1013
1013
1014
- def place (x , mask , vals , / ):
1014
+ def place (a , mask , vals ):
1015
1015
"""
1016
1016
Change elements of an array based on conditional and input values.
1017
1017
1018
+ Similar to ``dpnp.copyto(a, vals, where=mask)``, the difference is that
1019
+ :obj:`dpnp.place` uses the first N elements of `vals`, where N is
1020
+ the number of ``True`` values in `mask`, while :obj:`dpnp.copyto` uses
1021
+ the elements where `mask` is ``True``.
1022
+
1023
+ Note that :obj:`dpnp.extract` does the exact opposite of :obj:`dpnp.place`.
1024
+
1018
1025
For full documentation refer to :obj:`numpy.place`.
1019
1026
1020
- Limitations
1021
- -----------
1022
- Parameters `x`, `mask` and `vals` are supported either as
1023
- :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
1024
- Otherwise the function will be executed sequentially on CPU.
1027
+ Parameters
1028
+ ----------
1029
+ a : {dpnp.ndarray, usm_ndarray}
1030
+ Array to put data into.
1031
+ mask : {array_like, scalar}
1032
+ Boolean mask array. Must have the same size as `a`.
1033
+ vals : {array_like, scalar}
1034
+ Values to put into `a`. Only the first N elements are used, where N is
1035
+ the number of ``True`` values in `mask`. If `vals` is smaller than N,
1036
+ it will be repeated, and if elements of `a` are to be masked, this
1037
+ sequence must be non-empty.
1038
+
1039
+ See Also
1040
+ --------
1041
+ :obj:`dpnp.copyto` : Copies values from one array to another.
1042
+ :obj:`dpnp.put` : Replaces specified elements of an array with given values.
1043
+ :obj:`dpnp.take` : Take elements from an array along an axis.
1044
+ :obj:`dpnp.extract` : Return the elements of an array that satisfy some
1045
+ condition.
1046
+
1047
+ Examples
1048
+ --------
1049
+ >>> import dpnp as np
1050
+ >>> a = np.arange(6).reshape(2, 3)
1051
+ >>> np.place(a, a > 2, [44, 55])
1052
+ >>> a
1053
+ array([[ 0, 1, 2],
1054
+ [44, 55, 44]])
1055
+
1025
1056
"""
1026
1057
1027
- if (
1028
- dpnp .is_supported_array_type (x )
1029
- and dpnp .is_supported_array_type (mask )
1030
- and dpnp .is_supported_array_type (vals )
1031
- ):
1032
- dpt_array = x .get_array () if isinstance (x , dpnp_array ) else x
1033
- dpt_mask = mask .get_array () if isinstance (mask , dpnp_array ) else mask
1034
- dpt_vals = vals .get_array () if isinstance (vals , dpnp_array ) else vals
1035
- return dpt .place (dpt_array , dpt_mask , dpt_vals )
1036
-
1037
- return call_origin (numpy .place , x , mask , vals , dpnp_inplace = True )
1058
+ usm_a = dpnp .get_usm_ndarray (a )
1059
+ usm_mask = dpnp .as_usm_ndarray (
1060
+ mask ,
1061
+ dtype = dpnp .bool ,
1062
+ usm_type = usm_a .usm_type ,
1063
+ sycl_queue = usm_a .sycl_queue ,
1064
+ )
1065
+ usm_vals = dpnp .as_usm_ndarray (
1066
+ vals ,
1067
+ dtype = usm_a .dtype ,
1068
+ usm_type = usm_a .usm_type ,
1069
+ sycl_queue = usm_a .sycl_queue ,
1070
+ )
1071
+
1072
+ if usm_vals .ndim != 1 :
1073
+ # dpt.place supports only 1-D array of values
1074
+ usm_vals = dpt .reshape (usm_vals , - 1 )
1075
+
1076
+ if usm_vals .dtype != usm_a .dtype :
1077
+ # dpt.place casts values to a.dtype with "unsafe" rule,
1078
+ # while numpy.place does that with "safe" casting rule
1079
+ usm_vals = dpt .astype (usm_vals , usm_a .dtype , casting = "safe" , copy = False )
1080
+
1081
+ dpt .place (usm_a , usm_mask , usm_vals )
1038
1082
1039
1083
1040
1084
def put (a , ind , v , / , * , axis = None , mode = "wrap" ):
0 commit comments