Skip to content

Commit e07d5f0

Browse files
committed
Changed logic for in-place arithmetic operations on usm_ndarrays with themselves
- Now makes a copy and adds the copy to the original array
1 parent 0f9c857 commit e07d5f0

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -650,10 +650,6 @@ def inplace(self, lhs, val):
650650
raise TypeError(
651651
f"Expected dpctl.tensor.usm_ndarray, got {type(lhs)}"
652652
)
653-
if isinstance(val, dpt.usm_ndarray):
654-
if ti._array_overlap(lhs, val):
655-
# call standard operator in this case
656-
return self(lhs, val)
657653
q1, lhs_usm_type = _get_queue_usm_type(lhs)
658654
q2, val_usm_type = _get_queue_usm_type(val)
659655
if q2 is None:
@@ -727,10 +723,12 @@ def inplace(self, lhs, val):
727723

728724
if isinstance(val, dpt.usm_ndarray):
729725
rhs = val
726+
overlap = ti._array_overlap(lhs, rhs)
730727
else:
731728
rhs = dpt.asarray(val, dtype=val_dtype, sycl_queue=exec_q)
729+
overlap = False
732730

733-
if buf_dt == val_dtype:
731+
if buf_dt == val_dtype and overlap is False:
734732
rhs = dpt.broadcast_to(rhs, res_shape)
735733
ht_, _ = self.binary_inplace_fn_(
736734
lhs=lhs, rhs=rhs, sycl_queue=exec_q

0 commit comments

Comments
 (0)