Skip to content

Commit ded3f7d

Browse files
authored
Reuse add(), multiply() and subtract() from dpctl (#1430)
* Reuse add(), multiply() and subtract() from dpctl * add in-place support
1 parent 3e780ea commit ded3f7d

File tree

10 files changed

+247
-147
lines changed

10 files changed

+247
-147
lines changed

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
3636
cdef enum DPNPFuncName "DPNPFuncName":
3737
DPNP_FN_ABSOLUTE
3838
DPNP_FN_ABSOLUTE_EXT
39-
DPNP_FN_ADD
40-
DPNP_FN_ADD_EXT
4139
DPNP_FN_ALL
4240
DPNP_FN_ALL_EXT
4341
DPNP_FN_ALLCLOSE
@@ -117,7 +115,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
117115
DPNP_FN_DIAG_INDICES_EXT
118116
DPNP_FN_DIAGONAL
119117
DPNP_FN_DIAGONAL_EXT
120-
DPNP_FN_DIVIDE
121118
DPNP_FN_DOT
122119
DPNP_FN_DOT_EXT
123120
DPNP_FN_EDIFF1D
@@ -203,8 +200,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
203200
DPNP_FN_MINIMUM_EXT
204201
DPNP_FN_MODF
205202
DPNP_FN_MODF_EXT
206-
DPNP_FN_MULTIPLY
207-
DPNP_FN_MULTIPLY_EXT
208203
DPNP_FN_NANVAR
209204
DPNP_FN_NANVAR_EXT
210205
DPNP_FN_NEGATIVE
@@ -323,8 +318,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
323318
DPNP_FN_SQUARE_EXT
324319
DPNP_FN_STD
325320
DPNP_FN_STD_EXT
326-
DPNP_FN_SUBTRACT
327-
DPNP_FN_SUBTRACT_EXT
328321
DPNP_FN_SUM
329322
DPNP_FN_SUM_EXT
330323
DPNP_FN_SVD
@@ -523,8 +516,6 @@ cpdef dpnp_descriptor dpnp_copy(dpnp_descriptor x1)
523516
"""
524517
Mathematical functions
525518
"""
526-
cpdef dpnp_descriptor dpnp_add(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
527-
dpnp_descriptor out=*, object where=*)
528519
cpdef dpnp_descriptor dpnp_arctan2(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
529520
dpnp_descriptor out=*, object where=*)
530521
cpdef dpnp_descriptor dpnp_hypot(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
@@ -533,15 +524,11 @@ cpdef dpnp_descriptor dpnp_maximum(dpnp_descriptor x1_obj, dpnp_descriptor x2_ob
533524
dpnp_descriptor out=*, object where=*)
534525
cpdef dpnp_descriptor dpnp_minimum(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
535526
dpnp_descriptor out=*, object where=*)
536-
cpdef dpnp_descriptor dpnp_multiply(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
537-
dpnp_descriptor out=*, object where=*)
538527
cpdef dpnp_descriptor dpnp_negative(dpnp_descriptor array1)
539528
cpdef dpnp_descriptor dpnp_power(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
540529
dpnp_descriptor out=*, object where=*)
541530
cpdef dpnp_descriptor dpnp_remainder(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
542531
dpnp_descriptor out=*, object where=*)
543-
cpdef dpnp_descriptor dpnp_subtract(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
544-
dpnp_descriptor out=*, object where=*)
545532

546533

547534
"""

dpnp/dpnp_algo/dpnp_algo_mathematical.pxi

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ and the rest of the library
3737

3838
__all__ += [
3939
"dpnp_absolute",
40-
"dpnp_add",
4140
"dpnp_arctan2",
4241
"dpnp_around",
4342
"dpnp_ceil",
@@ -57,7 +56,6 @@ __all__ += [
5756
"dpnp_maximum",
5857
"dpnp_minimum",
5958
"dpnp_modf",
60-
"dpnp_multiply",
6159
"dpnp_nancumprod",
6260
"dpnp_nancumsum",
6361
"dpnp_nanprod",
@@ -67,7 +65,6 @@ __all__ += [
6765
"dpnp_prod",
6866
"dpnp_remainder",
6967
"dpnp_sign",
70-
"dpnp_subtract",
7168
"dpnp_sum",
7269
"dpnp_trapz",
7370
"dpnp_trunc"
@@ -123,14 +120,6 @@ cpdef utils.dpnp_descriptor dpnp_absolute(utils.dpnp_descriptor x1):
123120
return result
124121

125122

126-
cpdef utils.dpnp_descriptor dpnp_add(utils.dpnp_descriptor x1_obj,
127-
utils.dpnp_descriptor x2_obj,
128-
object dtype=None,
129-
utils.dpnp_descriptor out=None,
130-
object where=True):
131-
return call_fptr_2in_1out_strides(DPNP_FN_ADD_EXT, x1_obj, x2_obj, dtype, out, where)
132-
133-
134123
cpdef utils.dpnp_descriptor dpnp_arctan2(utils.dpnp_descriptor x1_obj,
135124
utils.dpnp_descriptor x2_obj,
136125
object dtype=None,
@@ -426,14 +415,6 @@ cpdef tuple dpnp_modf(utils.dpnp_descriptor x1):
426415
return (result1.get_pyobj(), result2.get_pyobj())
427416

428417

429-
cpdef utils.dpnp_descriptor dpnp_multiply(utils.dpnp_descriptor x1_obj,
430-
utils.dpnp_descriptor x2_obj,
431-
object dtype=None,
432-
utils.dpnp_descriptor out=None,
433-
object where=True):
434-
return call_fptr_2in_1out_strides(DPNP_FN_MULTIPLY_EXT, x1_obj, x2_obj, dtype, out, where)
435-
436-
437418
cpdef utils.dpnp_descriptor dpnp_nancumprod(utils.dpnp_descriptor x1):
438419
cur_x1 = dpnp_copy(x1).get_pyobj()
439420

@@ -586,14 +567,6 @@ cpdef utils.dpnp_descriptor dpnp_sign(utils.dpnp_descriptor x1):
586567
return call_fptr_1in_1out_strides(DPNP_FN_SIGN_EXT, x1)
587568

588569

589-
cpdef utils.dpnp_descriptor dpnp_subtract(utils.dpnp_descriptor x1_obj,
590-
utils.dpnp_descriptor x2_obj,
591-
object dtype=None,
592-
utils.dpnp_descriptor out=None,
593-
object where=True):
594-
return call_fptr_2in_1out_strides(DPNP_FN_SUBTRACT_EXT, x1_obj, x2_obj, dtype, out, where)
595-
596-
597570
cpdef utils.dpnp_descriptor dpnp_sum(utils.dpnp_descriptor x1,
598571
object axis=None,
599572
object dtype=None,

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,55 @@
3838

3939

4040
__all__ = [
41-
"dpnp_divide"
41+
"dpnp_add",
42+
"dpnp_divide",
43+
"dpnp_multiply",
44+
"dpnp_subtract"
4245
]
4346

4447

48+
_add_docstring_ = """
49+
add(x1, x2, out=None, order='K')
50+
51+
Calculates the sum for each element `x1_i` of the input array `x1` with
52+
the respective element `x2_i` of the input array `x2`.
53+
54+
Args:
55+
x1 (dpnp.ndarray):
56+
First input array, expected to have numeric data type.
57+
x2 (dpnp.ndarray):
58+
Second input array, also expected to have numeric data type.
59+
out ({None, dpnp.ndarray}, optional):
60+
Output array to populate.
61+
Array have the correct shape and the expected data type.
62+
order ("C","F","A","K", None, optional):
63+
Memory layout of the newly output array, if parameter `out` is `None`.
64+
Default: "K".
65+
Returns:
66+
dpnp.ndarray:
67+
an array containing the result of element-wise division. The data type
68+
of the returned array is determined by the Type Promotion Rules.
69+
"""
70+
71+
def dpnp_add(x1, x2, out=None, order='K'):
72+
"""
73+
Invokes add() from dpctl.tensor implementation for add() function.
74+
TODO: add a pybind11 extension of add() from OneMKL VM where possible
75+
and would be performance effective.
76+
77+
"""
78+
79+
# dpctl.tensor only works with usm_ndarray or scalar
80+
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
81+
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
82+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
83+
84+
func = BinaryElementwiseFunc("add", ti._add_result_type, ti._add,
85+
_add_docstring_, ti._add_inplace)
86+
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
87+
return dpnp_array._create_from_usm_ndarray(res_usm)
88+
89+
4590
_divide_docstring_ = """
4691
divide(x1, x2, out=None, order='K')
4792
@@ -88,3 +133,87 @@ def _call_divide(src1, src2, dst, sycl_queue, depends=[]):
88133
func = BinaryElementwiseFunc("divide", ti._divide_result_type, _call_divide, _divide_docstring_)
89134
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
90135
return dpnp_array._create_from_usm_ndarray(res_usm)
136+
137+
138+
_multiply_docstring_ = """
139+
multiply(x1, x2, out=None, order='K')
140+
141+
Calculates the product for each element `x1_i` of the input array `x1`
142+
with the respective element `x2_i` of the input array `x2`.
143+
144+
Args:
145+
x1 (dpnp.ndarray):
146+
First input array, expected to have numeric data type.
147+
x2 (dpnp.ndarray):
148+
Second input array, also expected to have numeric data type.
149+
out ({None, dpnp.ndarray}, optional):
150+
Output array to populate.
151+
Array have the correct shape and the expected data type.
152+
order ("C","F","A","K", None, optional):
153+
Memory layout of the newly output array, if parameter `out` is `None`.
154+
Default: "K".
155+
Returns:
156+
dpnp.ndarray:
157+
an array containing the result of element-wise division. The data type
158+
of the returned array is determined by the Type Promotion Rules.
159+
"""
160+
161+
def dpnp_multiply(x1, x2, out=None, order='K'):
162+
"""
163+
Invokes multiply() from dpctl.tensor implementation for multiply() function.
164+
TODO: add a pybind11 extension of mul() from OneMKL VM where possible
165+
and would be performance effective.
166+
167+
"""
168+
169+
# dpctl.tensor only works with usm_ndarray or scalar
170+
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
171+
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
172+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
173+
174+
func = BinaryElementwiseFunc("multiply", ti._multiply_result_type, ti._multiply,
175+
_multiply_docstring_, ti._multiply_inplace)
176+
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
177+
return dpnp_array._create_from_usm_ndarray(res_usm)
178+
179+
180+
_subtract_docstring_ = """
181+
subtract(x1, x2, out=None, order='K')
182+
183+
Calculates the difference bewteen each element `x1_i` of the input
184+
array `x1` and the respective element `x2_i` of the input array `x2`.
185+
186+
Args:
187+
x1 (dpnp.ndarray):
188+
First input array, expected to have numeric data type.
189+
x2 (dpnp.ndarray):
190+
Second input array, also expected to have numeric data type.
191+
out ({None, dpnp.ndarray}, optional):
192+
Output array to populate.
193+
Array have the correct shape and the expected data type.
194+
order ("C","F","A","K", None, optional):
195+
Memory layout of the newly output array, if parameter `out` is `None`.
196+
Default: "K".
197+
Returns:
198+
dpnp.ndarray:
199+
an array containing the result of element-wise division. The data type
200+
of the returned array is determined by the Type Promotion Rules.
201+
"""
202+
203+
def dpnp_subtract(x1, x2, out=None, order='K'):
204+
"""
205+
Invokes subtract() from dpctl.tensor implementation for subtract() function.
206+
TODO: add a pybind11 extension of sub() from OneMKL VM where possible
207+
and would be performance effective.
208+
209+
"""
210+
211+
# dpctl.tensor only works with usm_ndarray or scalar
212+
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
213+
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
214+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
215+
216+
func = BinaryElementwiseFunc("subtract", ti._subtract_result_type, ti._subtract,
217+
_subtract_docstring_, ti._subtract_inplace)
218+
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
219+
return dpnp_array._create_from_usm_ndarray(res_usm)

dpnp/dpnp_array.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,15 @@ def __irshift__(self, other):
249249
dpnp.right_shift(self, other, out=self)
250250
return self
251251

252-
# '__isub__',
252+
def __isub__(self, other):
253+
dpnp.subtract(self, other, out=self)
254+
return self
255+
253256
# '__iter__',
254-
# '__itruediv__',
257+
258+
def __itruediv__(self, other):
259+
dpnp.true_divide(self, other, out=self)
260+
return self
255261

256262
def __ixor__(self, other):
257263
dpnp.bitwise_xor(self, other, out=self)

0 commit comments

Comments
 (0)