Skip to content

In-place addition, multiplication, subtraction of usm_ndarrays #1237

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 13, 2023
134 changes: 132 additions & 2 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
_empty_like_pair_orderK,
_find_buf_dtype,
_find_buf_dtype2,
_find_inplace_dtype,
_to_device_supported_dtype,
)

Expand Down Expand Up @@ -331,11 +332,19 @@ class BinaryElementwiseFunc:
Class that implements binary element-wise functions.
"""

def __init__(self, name, result_type_resolver_fn, binary_dp_impl_fn, docs):
def __init__(
    self,
    name,
    result_type_resolver_fn,
    binary_dp_impl_fn,
    docs,
    binary_inplace_fn=None,
):
    """Store the pieces that make up a binary element-wise function.

    Args:
        name: Name of the function; used in ``repr`` and error messages.
        result_type_resolver_fn: Callable queried with a pair of input
            data types to obtain the result data type.
        binary_dp_impl_fn: Callable implementing the out-of-place
            operation.
        docs: Docstring to expose on the instance.
        binary_inplace_fn: Optional callable implementing the in-place
            operation. ``None`` (the default) means the function has no
            in-place implementation.
    """
    # Presentation attributes.
    self.__name__ = "BinaryElementwiseFunc"
    self.__doc__ = docs
    self.name_ = name
    # Callables used to resolve types and run the computation.
    self.result_type_resolver_fn_ = result_type_resolver_fn
    self.binary_fn_ = binary_dp_impl_fn
    self.binary_inplace_fn_ = binary_inplace_fn

def __str__(self):
Expand All @@ -345,6 +354,13 @@ def __repr__(self):
return f"<BinaryElementwiseFunc '{self.name_}'>"

def __call__(self, o1, o2, out=None, order="K"):
# FIXME: replace with check against base array
# when views can be identified
if o1 is out:
return self._inplace(o1, o2)
elif o2 is out:
return self._inplace(o2, o1)

if order not in ["K", "C", "F", "A"]:
order = "K"
q1, o1_usm_type = _get_queue_usm_type(o1)
Expand Down Expand Up @@ -388,6 +404,7 @@ def __call__(self, o1, o2, out=None, order="K"):
raise TypeError(
"Shape of arguments can not be inferred. "
"Arguments are expected to be "
"lists, tuples, or both"
)
try:
res_shape = _broadcast_shape_impl(
Expand Down Expand Up @@ -415,7 +432,7 @@ def __call__(self, o1, o2, out=None, order="K"):

if res_dt is None:
raise TypeError(
"function 'add' does not support input types "
f"function '{self.name_}' does not support input types "
f"({o1_dtype}, {o2_dtype}), "
"and the inputs could not be safely coerced to any "
"supported types according to the casting rule ''safe''."
Expand Down Expand Up @@ -631,3 +648,116 @@ def __call__(self, o1, o2, out=None, order="K"):
)
dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_])
return out

def _inplace(self, lhs, val):
    """Apply the in-place variant of this function, updating ``lhs``.

    Args:
        lhs: ``dpctl.tensor.usm_ndarray`` that is modified in place.
        val: Right-hand operand; either a ``usm_ndarray`` or an object
            that ``dpt.asarray`` can convert.

    Returns:
        ``lhs`` itself, after the in-place kernel has completed.

    Raises:
        ValueError: If no in-place implementation was registered, if the
            operand shapes do not broadcast, if the broadcast shape
            differs from the shape of ``lhs``, or if ``val`` has an
            unsupported type.
        TypeError: If ``lhs`` is not a ``usm_ndarray``, if operand
            shapes cannot be inferred, or if no suitable data type for
            the right-hand operand could be found.
        ExecutionPlacementError: If no common execution queue can be
            determined for the operands.
    """
    inplace_fn = self.binary_inplace_fn_
    if inplace_fn is None:
        raise ValueError(
            f"In-place operation not supported for ufunc '{self.name_}'"
        )
    if not isinstance(lhs, dpt.usm_ndarray):
        raise TypeError(
            f"Expected dpctl.tensor.usm_ndarray, got {type(lhs)}"
        )

    # --- execution queue and USM type --------------------------------
    lhs_q, lhs_usm_type = _get_queue_usm_type(lhs)
    val_q, val_usm_type = _get_queue_usm_type(val)
    if val_q is not None:
        exec_q = dpctl.utils.get_execution_queue((lhs_q, val_q))
        if exec_q is None:
            raise ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments."
            )
        usm_type = dpctl.utils.get_coerced_usm_type(
            (lhs_usm_type, val_usm_type)
        )
    else:
        # ``val`` carries no queue, so the left operand decides.
        exec_q = lhs_q
        usm_type = lhs_usm_type
    dpctl.utils.validate_usm_type(usm_type, allow_none=False)

    # --- shapes ------------------------------------------------------
    lhs_shape = _get_shape(lhs)
    val_shape = _get_shape(val)
    if any(
        not isinstance(shape, (tuple, list))
        for shape in (lhs_shape, val_shape)
    ):
        raise TypeError(
            "Shape of arguments can not be inferred. "
            "Arguments are expected to be "
            "lists, tuples, or both"
        )
    try:
        res_shape = _broadcast_shape_impl([lhs_shape, val_shape])
    except ValueError:
        raise ValueError(
            "operands could not be broadcast together with shapes "
            f"{lhs_shape} and {val_shape}"
        )
    # The result is written into ``lhs``, so broadcasting must not
    # enlarge it.
    if res_shape != lhs_shape:
        raise ValueError(
            f"output shape {lhs_shape} does not match "
            f"broadcast shape {res_shape}"
        )

    # --- data types --------------------------------------------------
    sycl_dev = exec_q.sycl_device
    lhs_dtype = lhs.dtype
    val_dtype = _get_dtype(val, sycl_dev)
    if not _validate_dtype(val_dtype):
        raise ValueError("Input operand of unsupported type")

    lhs_dtype, val_dtype = _resolve_weak_types(
        lhs_dtype, val_dtype, sycl_dev
    )
    buf_dt = _find_inplace_dtype(
        lhs_dtype, val_dtype, self.result_type_resolver_fn_, sycl_dev
    )
    if buf_dt is None:
        raise TypeError(
            f"In-place '{self.name_}' does not support input types "
            f"({lhs_dtype}, {val_dtype}), "
            "and the inputs could not be safely coerced to any "
            "supported types according to the casting rule ''safe''."
        )

    # --- right-hand operand as an array ------------------------------
    if isinstance(val, dpt.usm_ndarray):
        rhs = val
        overlap = ti._array_overlap(lhs, rhs)
    else:
        rhs = dpt.asarray(val, dtype=val_dtype, sycl_queue=exec_q)
        overlap = False

    # --- fast path: no cast needed and no memory overlap -------------
    if buf_dt == val_dtype and overlap is False:
        rhs_view = dpt.broadcast_to(rhs, res_shape)
        host_ev, _ = inplace_fn(lhs=lhs, rhs=rhs_view, sycl_queue=exec_q)
        host_ev.wait()
        return lhs

    # --- otherwise copy the right operand into a temporary first -----
    tmp = dpt.empty_like(rhs, dtype=buf_dt)
    host_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=rhs, dst=tmp, sycl_queue=exec_q
    )
    tmp = dpt.broadcast_to(tmp, res_shape)
    host_ev, _ = inplace_fn(
        lhs=lhs,
        rhs=tmp,
        sycl_queue=exec_q,
        depends=[copy_ev],  # kernel must wait for the copy to finish
    )
    host_copy_ev.wait()
    host_ev.wait()
    return lhs
18 changes: 15 additions & 3 deletions dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@
returned array is determined by the Type Promotion Rules.
"""
add = BinaryElementwiseFunc(
"add", ti._add_result_type, ti._add, _add_docstring_
"add",
ti._add_result_type,
ti._add,
_add_docstring_,
binary_inplace_fn=ti._add_inplace,
)

# U04: ===== ASIN (x)
Expand Down Expand Up @@ -603,7 +607,11 @@
the returned array is determined by the Type Promotion Rules.
"""
multiply = BinaryElementwiseFunc(
"multiply", ti._multiply_result_type, ti._multiply, _multiply_docstring_
"multiply",
ti._multiply_result_type,
ti._multiply,
_multiply_docstring_,
ti._multiply_inplace,
)

# U25: ==== NEGATIVE (x)
Expand Down Expand Up @@ -782,7 +790,11 @@
of the returned array is determined by the Type Promotion Rules.
"""
subtract = BinaryElementwiseFunc(
"subtract", ti._subtract_result_type, ti._subtract, _subtract_docstring_
"subtract",
ti._subtract_result_type,
ti._subtract,
_subtract_docstring_,
ti._subtract_inplace,
)


Expand Down
18 changes: 18 additions & 0 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,27 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev):
return None, None, None


def _find_inplace_dtype(lhs_dtype, rhs_dtype, query_fn, sycl_dev):
    """Find a data type for the right operand of an in-place operation.

    An in-place operation writes into the left operand, so a data type
    for the right operand is acceptable only when ``query_fn`` reports
    a result type equal to ``lhs_dtype``.

    Args:
        lhs_dtype: Data type of the array being modified.
        rhs_dtype: Data type of the right-hand operand.
        query_fn: Callable ``query_fn(lhs_dtype, candidate_dtype)``
            returning the result data type, or ``None`` when the
            combination is not supported.
        sycl_dev: Device object; only its ``has_aspect_fp16`` and
            ``has_aspect_fp64`` attributes are read, and only when
            ``rhs_dtype`` cannot be used as is.

    Returns:
        ``rhs_dtype`` if it can be used directly; otherwise the first
        type yielded by ``_all_data_types`` to which ``rhs_dtype`` can
        be cast (per ``_can_cast``) and that keeps the result type equal
        to ``lhs_dtype``; ``None`` if there is no such type.
    """

    def _keeps_lhs_dtype(candidate_dt):
        # Compare against None explicitly instead of relying on the
        # truthiness of the result: a valid dtype-like object may be
        # falsy (e.g. one that reports zero length), and the rest of
        # this code base signals "unsupported" with None.
        res_dt = query_fn(lhs_dtype, candidate_dt)
        return res_dt is not None and res_dt == lhs_dtype

    if _keeps_lhs_dtype(rhs_dtype):
        return rhs_dtype

    _fp16 = sycl_dev.has_aspect_fp16
    _fp64 = sycl_dev.has_aspect_fp64
    for buf_dt in _all_data_types(_fp16, _fp64):
        castable = _can_cast(rhs_dtype, buf_dt, _fp16, _fp64)
        if castable and _keeps_lhs_dtype(buf_dt):
            return buf_dt

    return None


__all__ = [
"_find_buf_dtype",
"_find_buf_dtype2",
"_find_inplace_dtype",
"_empty_like_orderK",
"_empty_like_pair_orderK",
"_to_device_supported_dtype",
Expand Down
21 changes: 6 additions & 15 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1245,11 +1245,8 @@ cdef class usm_ndarray:
return _dispatch_binary_elementwise2(other, "logical_xor", self)

def __iadd__(self, other):
res = self.__add__(other)
if res is NotImplemented:
return res
self.__setitem__(Ellipsis, res)
return self
from ._elementwise_funcs import add
return add._inplace(self, other)

def __iand__(self, other):
res = self.__and__(other)
Expand Down Expand Up @@ -1287,11 +1284,8 @@ cdef class usm_ndarray:
return self

def __imul__(self, other):
res = self.__mul__(other)
if res is NotImplemented:
return res
self.__setitem__(Ellipsis, res)
return self
from ._elementwise_funcs import multiply
return multiply._inplace(self, other)

def __ior__(self, other):
res = self.__or__(other)
Expand All @@ -1315,11 +1309,8 @@ cdef class usm_ndarray:
return self

def __isub__(self, other):
res = self.__sub__(other)
if res is NotImplemented:
return res
self.__setitem__(Ellipsis, res)
return self
from ._elementwise_funcs import subtract
return subtract._inplace(self, other)

def __itruediv__(self, other):
res = self.__truediv__(other)
Expand Down
Loading