Skip to content

Commit cf4049a

Browse files
committed
In-place operations enabled in standard binary operators
- functionality such as binop(x, y, out=x) now possible, with some edge cases still WIP
1 parent e07d5f0 commit cf4049a

File tree

3 files changed

+34
-8
lines changed

3 files changed

+34
-8
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,13 @@ def __repr__(self):
354354
return f"<BinaryElementwiseFunc '{self.name_}'>"
355355

356356
def __call__(self, o1, o2, out=None, order="K"):
357+
# FIXME: replace with check against base array
358+
# when views can be identified
359+
if o1 is out:
360+
return self._inplace(o1, o2)
361+
elif o2 is out:
362+
return self._inplace(o2, o1)
363+
357364
if order not in ["K", "C", "F", "A"]:
358365
order = "K"
359366
q1, o1_usm_type = _get_queue_usm_type(o1)
@@ -397,6 +404,7 @@ def __call__(self, o1, o2, out=None, order="K"):
397404
raise TypeError(
398405
"Shape of arguments can not be inferred. "
399406
"Arguments are expected to be "
407+
"lists, tuples, or both"
400408
)
401409
try:
402410
res_shape = _broadcast_shape_impl(
@@ -424,7 +432,7 @@ def __call__(self, o1, o2, out=None, order="K"):
424432

425433
if res_dt is None:
426434
raise TypeError(
427-
"function 'add' does not support input types "
435+
f"function '{self.name_}' does not support input types "
428436
f"({o1_dtype}, {o2_dtype}), "
429437
"and the inputs could not be safely coerced to any "
430438
"supported types according to the casting rule ''safe''."
@@ -641,7 +649,7 @@ def __call__(self, o1, o2, out=None, order="K"):
641649
dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_])
642650
return out
643651

644-
def inplace(self, lhs, val):
652+
def _inplace(self, lhs, val):
645653
if self.binary_inplace_fn_ is None:
646654
raise ValueError(
647655
f"In-place operation not supported for ufunc '{self.name_}'"
@@ -681,6 +689,7 @@ def inplace(self, lhs, val):
681689
raise TypeError(
682690
"Shape of arguments can not be inferred. "
683691
"Arguments are expected to be "
692+
"lists, tuples, or both"
684693
)
685694
try:
686695
res_shape = _broadcast_shape_impl(
@@ -715,7 +724,7 @@ def inplace(self, lhs, val):
715724

716725
if buf_dt is None:
717726
raise TypeError(
718-
f"function '{self.name_}' does not support input types "
727+
f"In-place '{self.name_}' does not support input types "
719728
f"({lhs_dtype}, {val_dtype}), "
720729
"and the inputs could not be safely coerced to any "
721730
"supported types according to the casting rule ''safe''."

dpctl/tensor/_usmarray.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1246,7 +1246,7 @@ cdef class usm_ndarray:
12461246

12471247
def __iadd__(self, other):
12481248
from ._elementwise_funcs import add
1249-
return add.inplace(self, other)
1249+
return add._inplace(self, other)
12501250

12511251
def __iand__(self, other):
12521252
res = self.__and__(other)
@@ -1285,7 +1285,7 @@ cdef class usm_ndarray:
12851285

12861286
def __imul__(self, other):
12871287
from ._elementwise_funcs import multiply
1288-
return multiply.inplace(self, other)
1288+
return multiply._inplace(self, other)
12891289

12901290
def __ior__(self, other):
12911291
res = self.__or__(other)
@@ -1310,7 +1310,7 @@ cdef class usm_ndarray:
13101310

13111311
def __isub__(self, other):
13121312
from ._elementwise_funcs import subtract
1313-
return subtract.inplace(self, other)
1313+
return subtract._inplace(self, other)
13141314

13151315
def __itruediv__(self, other):
13161316
res = self.__truediv__(other)

dpctl/tests/elementwise/test_add.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,8 @@ def test_add_errors():
294294

295295
ar1 = dpt.ones(2, dtype="float32")
296296
ar2 = dpt.ones_like(ar1, dtype="int32")
297-
y = ar1
297+
# identical view but a different object
298+
y = ar1[:]
298299
assert_raises_regex(
299300
TypeError,
300301
"Input and output arrays have memory overlap",
@@ -423,7 +424,7 @@ def test_add_inplace_errors():
423424
ar1 = np.ones(2, dtype="float32")
424425
ar2 = dpt.ones(2, dtype="float32")
425426
with pytest.raises(TypeError):
426-
dpt.add.inplace(ar1, ar2)
427+
ar1 += ar2
427428

428429
ar1 = dpt.ones(2, dtype="float32")
429430
ar2 = dict()
@@ -434,3 +435,19 @@ def test_add_inplace_errors():
434435
ar2 = dpt.ones((1, 2), dtype="float32")
435436
with pytest.raises(ValueError):
436437
ar1 += ar2
438+
439+
440+
def test_add_inplace_overlap():
441+
get_queue_or_skip()
442+
443+
ar1 = dpt.ones(10, dtype="i4")
444+
ar1 += ar1
445+
assert (dpt.asnumpy(ar1) == np.full(ar1.shape, 2, dtype="i4")).all()
446+
447+
ar1 = dpt.ones(10, dtype="i4")
448+
ar2 = dpt.ones(10, dtype="i4")
449+
dpt.add(ar1, ar2, out=ar1)
450+
assert (dpt.asnumpy(ar1) == np.full(ar1.shape, 2, dtype="i4")).all()
451+
452+
dpt.add(ar2, ar1, out=ar2)
453+
assert (dpt.asnumpy(ar2) == np.full(ar2.shape, 3, dtype="i4")).all()

0 commit comments

Comments
 (0)