Skip to content

Commit 0c7f196

Browse files
committed
Add inplace support of divide (#1434)
1 parent 2b4f173 commit 0c7f196

File tree

4 files changed

+39
-12
lines changed

4 files changed

+39
-12
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@
3535
BinaryElementwiseFunc
3636
)
3737
import dpctl.tensor._tensor_impl as ti
38+
import dpctl.tensor as dpt
39+
import dpctl
40+
41+
import numpy
3842

3943

4044
__all__ = [
@@ -125,12 +129,27 @@ def _call_divide(src1, src2, dst, sycl_queue, depends=[]):
125129
return vmi._div(sycl_queue, src1, src2, dst, depends)
126130
return ti._divide(src1, src2, dst, sycl_queue, depends)
127131

132+
def _call_divide_inplace(lhs, rhs, sycl_queue, depends=None):
    """
    In-place divide workaround until dpctl.tensor provides the functionality.

    The quotient ``lhs / rhs`` is first computed into a temporary array and
    then copied back into ``lhs``.

    Parameters
    ----------
    lhs
        Left operand; it also receives the result of the division.
    rhs
        Right operand.
    sycl_queue
        Queue forwarded to the divide and copy calls.
    depends : list, optional
        Events forwarded to the divide call. Defaults to an empty list.

    Returns
    -------
    tuple
        The pair of events returned by the copy of the temporary array
        into ``lhs``.
    """
    # avoid a mutable default argument shared between calls
    depends = [] if depends is None else depends

    # allocate temporary memory for out array;
    # numpy.result_type expects the dtypes as separate positional arguments,
    # a single tuple would be parsed as one dtype specification instead
    out = dpt.empty_like(lhs, dtype=numpy.result_type(lhs.dtype, rhs.dtype))

    # call a general callback
    div_ht_, div_ev_ = _call_divide(lhs, rhs, out, sycl_queue, depends)

    # store the result into left input array; the copy is ordered after
    # the divide through its `depends` list
    cp_ht_, cp_ev_ = ti._copy_usm_ndarray_into_usm_ndarray(
        src=out, dst=lhs, sycl_queue=sycl_queue, depends=[div_ev_]
    )

    # wait on the first event returned by the divide call
    # (presumably its host-task event -- TODO confirm) before returning
    dpctl.SyclEvent.wait_for([div_ht_])
    return (cp_ht_, cp_ev_)
145+
128146
# dpctl.tensor only works with usm_ndarray or scalar
129147
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
130148
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
131149
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
132150

133-
func = BinaryElementwiseFunc("divide", ti._divide_result_type, _call_divide, _divide_docstring_)
151+
func = BinaryElementwiseFunc("divide", ti._divide_result_type, _call_divide,
152+
_divide_docstring_, _call_divide_inplace)
134153
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
135154
return dpnp_array._create_from_usm_ndarray(res_usm)
136155

@@ -208,6 +227,11 @@ def dpnp_subtract(x1, x2, out=None, order='K'):
208227
209228
"""
210229

230+
# TODO: discuss with dpctl whether this check should be moved there
231+
if not dpnp.isscalar(x1) and not dpnp.isscalar(x2) and x1.dtype == x2.dtype == dpnp.bool:
232+
raise TypeError("DPNP boolean subtract, the `-` operator, is not supported, "
233+
"use the bitwise_xor, the `^` operator, or the logical_xor function instead.")
234+
211235
# dpctl.tensor only works with usm_ndarray or scalar
212236
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
213237
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)

tests/test_usm_type.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ def test_coerced_usm_types_sum(usm_type_x, usm_type_y):
1919

2020
z = 1.3 + x + y + 2
2121

22-
# TODO: unmute once dpctl support that
23-
# z += x
24-
# z += 7.4
22+
# inplace add
23+
z += x
24+
z += 7.4
2525

2626
assert x.usm_type == usm_type_x
2727
assert y.usm_type == usm_type_y
@@ -36,9 +36,9 @@ def test_coerced_usm_types_mul(usm_type_x, usm_type_y):
3636

3737
z = 3 * x * y * 1.5
3838

39-
# TODO: unmute once dpctl support that
40-
# z *= x
41-
# z *= 4.8
39+
# inplace multiply
40+
z *= x
41+
z *= 4.8
4242

4343
assert x.usm_type == usm_type_x
4444
assert y.usm_type == usm_type_y
@@ -53,6 +53,10 @@ def test_coerced_usm_types_subtract(usm_type_x, usm_type_y):
5353

5454
z = 20 - x - y - 7.4
5555

56+
# inplace subtract
57+
z -= x
58+
z -= -3.4
59+
5660
assert x.usm_type == usm_type_x
5761
assert y.usm_type == usm_type_y
5862
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
@@ -66,6 +70,10 @@ def test_coerced_usm_types_divide(usm_type_x, usm_type_y):
6670

6771
z = 2 / x / y / 1.5
6872

73+
# inplace divide
74+
z /= x
75+
z /= -2.4
76+
6977
assert x.usm_type == usm_type_x
7078
assert y.usm_type == usm_type_y
7179
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])

tests/third_party/cupy/linalg_tests/test_product.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,6 @@ def test_transposed_multidim_vdot(self, xp, dtype):
228228
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
229229
@testing.for_all_dtypes()
230230
@testing.numpy_cupy_allclose()
231-
@pytest.mark.skip("mute until dpctl support in-place add")
232231
def test_inner(self, xp, dtype):
233232
a = testing.shaped_arange((5,), xp, dtype)
234233
b = testing.shaped_reverse_arange((5,), xp, dtype)
@@ -237,7 +236,6 @@ def test_inner(self, xp, dtype):
237236
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
238237
@testing.for_all_dtypes()
239238
@testing.numpy_cupy_allclose()
240-
@pytest.mark.skip("mute until dpctl support in-place add")
241239
def test_reversed_inner(self, xp, dtype):
242240
a = testing.shaped_arange((5,), xp, dtype)[::-1]
243241
b = testing.shaped_reverse_arange((5,), xp, dtype)[::-1]
@@ -246,15 +244,13 @@ def test_reversed_inner(self, xp, dtype):
246244
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
247245
@testing.for_all_dtypes()
248246
@testing.numpy_cupy_allclose()
249-
@pytest.mark.skip("mute until dpctl support in-place add")
250247
def test_multidim_inner(self, xp, dtype):
251248
a = testing.shaped_arange((2, 3, 4), xp, dtype)
252249
b = testing.shaped_arange((3, 2, 4), xp, dtype)
253250
return xp.inner(a, b)
254251

255252
@testing.for_all_dtypes()
256253
@testing.numpy_cupy_allclose()
257-
@pytest.mark.skip("mute until dpctl support in-place add")
258254
def test_transposed_higher_order_inner(self, xp, dtype):
259255
a = testing.shaped_arange((2, 4, 3), xp, dtype).transpose(2, 0, 1)
260256
b = testing.shaped_arange((4, 2, 3), xp, dtype).transpose(1, 2, 0)

tests/third_party/cupy/math_tests/test_arithmetic.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,6 @@ def test_modf(self, xp, dtype):
280280
'shape': [(3, 2), (), (3, 0, 2)]
281281
}))
282282
@testing.gpu
283-
@pytest.mark.skip("dpctl doesn't raise an error")
284283
class TestBoolSubtract(unittest.TestCase):
285284

286285
def test_bool_subtract(self):

0 commit comments

Comments
 (0)