Skip to content

Commit 81553f8

Browse files
authored
Merge pull request #1237 from IntelPython/inplace-operator-initial-impl
In-place addition, multiplication, subtraction of usm_ndarrays
2 parents 43f3b7b + da5f2f7 commit 81553f8

File tree

14 files changed

+1378
-23
lines changed

14 files changed

+1378
-23
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
_empty_like_pair_orderK,
3232
_find_buf_dtype,
3333
_find_buf_dtype2,
34+
_find_inplace_dtype,
3435
_to_device_supported_dtype,
3536
)
3637

@@ -331,11 +332,19 @@ class BinaryElementwiseFunc:
331332
Class that implements binary element-wise functions.
332333
"""
333334

334-
def __init__(self, name, result_type_resolver_fn, binary_dp_impl_fn, docs):
335+
def __init__(
336+
self,
337+
name,
338+
result_type_resolver_fn,
339+
binary_dp_impl_fn,
340+
docs,
341+
binary_inplace_fn=None,
342+
):
335343
self.__name__ = "BinaryElementwiseFunc"
336344
self.name_ = name
337345
self.result_type_resolver_fn_ = result_type_resolver_fn
338346
self.binary_fn_ = binary_dp_impl_fn
347+
self.binary_inplace_fn_ = binary_inplace_fn
339348
self.__doc__ = docs
340349

341350
def __str__(self):
@@ -345,6 +354,13 @@ def __repr__(self):
345354
return f"<BinaryElementwiseFunc '{self.name_}'>"
346355

347356
def __call__(self, o1, o2, out=None, order="K"):
357+
# FIXME: replace with check against base array
358+
# when views can be identified
359+
if o1 is out:
360+
return self._inplace(o1, o2)
361+
elif o2 is out:
362+
return self._inplace(o2, o1)
363+
348364
if order not in ["K", "C", "F", "A"]:
349365
order = "K"
350366
q1, o1_usm_type = _get_queue_usm_type(o1)
@@ -388,6 +404,7 @@ def __call__(self, o1, o2, out=None, order="K"):
388404
raise TypeError(
389405
"Shape of arguments can not be inferred. "
390406
"Arguments are expected to be "
407+
"lists, tuples, or both"
391408
)
392409
try:
393410
res_shape = _broadcast_shape_impl(
@@ -415,7 +432,7 @@ def __call__(self, o1, o2, out=None, order="K"):
415432

416433
if res_dt is None:
417434
raise TypeError(
418-
"function 'add' does not support input types "
435+
f"function '{self.name_}' does not support input types "
419436
f"({o1_dtype}, {o2_dtype}), "
420437
"and the inputs could not be safely coerced to any "
421438
"supported types according to the casting rule ''safe''."
@@ -631,3 +648,116 @@ def __call__(self, o1, o2, out=None, order="K"):
631648
)
632649
dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_])
633650
return out
651+
652+
def _inplace(self, lhs, val):
653+
if self.binary_inplace_fn_ is None:
654+
raise ValueError(
655+
f"In-place operation not supported for ufunc '{self.name_}'"
656+
)
657+
if not isinstance(lhs, dpt.usm_ndarray):
658+
raise TypeError(
659+
f"Expected dpctl.tensor.usm_ndarray, got {type(lhs)}"
660+
)
661+
q1, lhs_usm_type = _get_queue_usm_type(lhs)
662+
q2, val_usm_type = _get_queue_usm_type(val)
663+
if q2 is None:
664+
exec_q = q1
665+
usm_type = lhs_usm_type
666+
else:
667+
exec_q = dpctl.utils.get_execution_queue((q1, q2))
668+
if exec_q is None:
669+
raise ExecutionPlacementError(
670+
"Execution placement can not be unambiguously inferred "
671+
"from input arguments."
672+
)
673+
usm_type = dpctl.utils.get_coerced_usm_type(
674+
(
675+
lhs_usm_type,
676+
val_usm_type,
677+
)
678+
)
679+
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
680+
lhs_shape = _get_shape(lhs)
681+
val_shape = _get_shape(val)
682+
if not all(
683+
isinstance(s, (tuple, list))
684+
for s in (
685+
lhs_shape,
686+
val_shape,
687+
)
688+
):
689+
raise TypeError(
690+
"Shape of arguments can not be inferred. "
691+
"Arguments are expected to be "
692+
"lists, tuples, or both"
693+
)
694+
try:
695+
res_shape = _broadcast_shape_impl(
696+
[
697+
lhs_shape,
698+
val_shape,
699+
]
700+
)
701+
except ValueError:
702+
raise ValueError(
703+
"operands could not be broadcast together with shapes "
704+
f"{lhs_shape} and {val_shape}"
705+
)
706+
if res_shape != lhs_shape:
707+
raise ValueError(
708+
f"output shape {lhs_shape} does not match "
709+
f"broadcast shape {res_shape}"
710+
)
711+
sycl_dev = exec_q.sycl_device
712+
lhs_dtype = lhs.dtype
713+
val_dtype = _get_dtype(val, sycl_dev)
714+
if not _validate_dtype(val_dtype):
715+
raise ValueError("Input operand of unsupported type")
716+
717+
lhs_dtype, val_dtype = _resolve_weak_types(
718+
lhs_dtype, val_dtype, sycl_dev
719+
)
720+
721+
buf_dt = _find_inplace_dtype(
722+
lhs_dtype, val_dtype, self.result_type_resolver_fn_, sycl_dev
723+
)
724+
725+
if buf_dt is None:
726+
raise TypeError(
727+
f"In-place '{self.name_}' does not support input types "
728+
f"({lhs_dtype}, {val_dtype}), "
729+
"and the inputs could not be safely coerced to any "
730+
"supported types according to the casting rule ''safe''."
731+
)
732+
733+
if isinstance(val, dpt.usm_ndarray):
734+
rhs = val
735+
overlap = ti._array_overlap(lhs, rhs)
736+
else:
737+
rhs = dpt.asarray(val, dtype=val_dtype, sycl_queue=exec_q)
738+
overlap = False
739+
740+
if buf_dt == val_dtype and overlap is False:
741+
rhs = dpt.broadcast_to(rhs, res_shape)
742+
ht_, _ = self.binary_inplace_fn_(
743+
lhs=lhs, rhs=rhs, sycl_queue=exec_q
744+
)
745+
ht_.wait()
746+
747+
else:
748+
buf = dpt.empty_like(rhs, dtype=buf_dt)
749+
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
750+
src=rhs, dst=buf, sycl_queue=exec_q
751+
)
752+
753+
buf = dpt.broadcast_to(buf, res_shape)
754+
ht_, _ = self.binary_inplace_fn_(
755+
lhs=lhs,
756+
rhs=buf,
757+
sycl_queue=exec_q,
758+
depends=[copy_ev],
759+
)
760+
ht_copy_ev.wait()
761+
ht_.wait()
762+
763+
return lhs

dpctl/tensor/_elementwise_funcs.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,11 @@
7373
returned array is determined by the Type Promotion Rules.
7474
"""
7575
add = BinaryElementwiseFunc(
76-
"add", ti._add_result_type, ti._add, _add_docstring_
76+
"add",
77+
ti._add_result_type,
78+
ti._add,
79+
_add_docstring_,
80+
binary_inplace_fn=ti._add_inplace,
7781
)
7882

7983
# U04: ===== ASIN (x)
@@ -603,7 +607,11 @@
603607
the returned array is determined by the Type Promotion Rules.
604608
"""
605609
multiply = BinaryElementwiseFunc(
606-
"multiply", ti._multiply_result_type, ti._multiply, _multiply_docstring_
610+
"multiply",
611+
ti._multiply_result_type,
612+
ti._multiply,
613+
_multiply_docstring_,
614+
ti._multiply_inplace,
607615
)
608616

609617
# U25: ==== NEGATIVE (x)
@@ -782,7 +790,11 @@
782790
of the returned array is determined by the Type Promotion Rules.
783791
"""
784792
subtract = BinaryElementwiseFunc(
785-
"subtract", ti._subtract_result_type, ti._subtract, _subtract_docstring_
793+
"subtract",
794+
ti._subtract_result_type,
795+
ti._subtract,
796+
_subtract_docstring_,
797+
ti._subtract_inplace,
786798
)
787799

788800

dpctl/tensor/_type_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,27 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
294294
return None, None, None
295295

296296

297+
def _find_inplace_dtype(lhs_dtype, rhs_dtype, query_fn, sycl_dev):
298+
res_dt = query_fn(lhs_dtype, rhs_dtype)
299+
if res_dt and res_dt == lhs_dtype:
300+
return rhs_dtype
301+
302+
_fp16 = sycl_dev.has_aspect_fp16
303+
_fp64 = sycl_dev.has_aspect_fp64
304+
all_dts = _all_data_types(_fp16, _fp64)
305+
for buf_dt in all_dts:
306+
if _can_cast(rhs_dtype, buf_dt, _fp16, _fp64):
307+
res_dt = query_fn(lhs_dtype, buf_dt)
308+
if res_dt and res_dt == lhs_dtype:
309+
return buf_dt
310+
311+
return None
312+
313+
297314
__all__ = [
298315
"_find_buf_dtype",
299316
"_find_buf_dtype2",
317+
"_find_inplace_dtype",
300318
"_empty_like_orderK",
301319
"_empty_like_pair_orderK",
302320
"_to_device_supported_dtype",

dpctl/tensor/_usmarray.pyx

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,11 +1245,8 @@ cdef class usm_ndarray:
12451245
return _dispatch_binary_elementwise2(other, "logical_xor", self)
12461246

12471247
def __iadd__(self, other):
1248-
res = self.__add__(other)
1249-
if res is NotImplemented:
1250-
return res
1251-
self.__setitem__(Ellipsis, res)
1252-
return self
1248+
from ._elementwise_funcs import add
1249+
return add._inplace(self, other)
12531250

12541251
def __iand__(self, other):
12551252
res = self.__and__(other)
@@ -1287,11 +1284,8 @@ cdef class usm_ndarray:
12871284
return self
12881285

12891286
def __imul__(self, other):
1290-
res = self.__mul__(other)
1291-
if res is NotImplemented:
1292-
return res
1293-
self.__setitem__(Ellipsis, res)
1294-
return self
1287+
from ._elementwise_funcs import multiply
1288+
return multiply._inplace(self, other)
12951289

12961290
def __ior__(self, other):
12971291
res = self.__or__(other)
@@ -1315,11 +1309,8 @@ cdef class usm_ndarray:
13151309
return self
13161310

13171311
def __isub__(self, other):
1318-
res = self.__sub__(other)
1319-
if res is NotImplemented:
1320-
return res
1321-
self.__setitem__(Ellipsis, res)
1322-
return self
1312+
from ._elementwise_funcs import subtract
1313+
return subtract._inplace(self, other)
13231314

13241315
def __itruediv__(self, other):
13251316
res = self.__truediv__(other)

0 commit comments

Comments
 (0)