Skip to content

Adds element-wise functions angle and reciprocal #1474

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dpctl/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ set(_elementwise_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acos.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acosh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/add.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/angle.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asin.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asinh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan.cpp
Expand Down Expand Up @@ -87,6 +88,7 @@ set(_elementwise_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/pow.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/proj.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/real.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/reciprocal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/remainder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/round.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/rsqrt.cpp
Expand Down
4 changes: 4 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
acos,
acosh,
add,
angle,
asin,
asinh,
atan,
Expand Down Expand Up @@ -153,6 +154,7 @@
pow,
proj,
real,
reciprocal,
remainder,
round,
rsqrt,
Expand Down Expand Up @@ -342,4 +344,6 @@
"var",
"__array_api_version__",
"__array_namespace_info__",
"reciprocal",
"angle",
]
52 changes: 48 additions & 4 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@

from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
from ._type_utils import (
_acceptance_fn_default,
_acceptance_fn_default_binary,
_acceptance_fn_default_unary,
_all_data_types,
_find_buf_dtype,
_find_buf_dtype2,
Expand Down Expand Up @@ -62,17 +63,39 @@ class UnaryElementwiseFunc:
computational tasks complete execution, while the second event
corresponds to computational tasks associated with function
evaluation.
acceptance_fn (callable, optional):
Function to influence type promotion behavior of this unary
function. The function takes 4 arguments:
arg_dtype - Data type of the first argument
buf_dtype - Data type the argument would be cast to
res_dtype - Data type of the output array with function values
sycl_dev - The :class:`dpctl.SyclDevice` where the function
evaluation is carried out.
The function is invoked when the argument of the unary function
requires casting, e.g. the argument of `dpctl.tensor.log` is an
array with integral data type.
docs (str):
Documentation string for the unary function.
"""

def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
def __init__(
self,
name,
result_type_resolver_fn,
unary_dp_impl_fn,
docs,
acceptance_fn=None,
):
self.__name__ = "UnaryElementwiseFunc"
self.name_ = name
self.result_type_resolver_fn_ = result_type_resolver_fn
self.types_ = None
self.unary_fn_ = unary_dp_impl_fn
self.__doc__ = docs
if callable(acceptance_fn):
self.acceptance_fn_ = acceptance_fn
else:
self.acceptance_fn_ = _acceptance_fn_default_unary

def __str__(self):
return f"<{self.__name__} '{self.name_}'>"
Expand All @@ -93,6 +116,24 @@ def get_type_result_resolver_function(self):
"""
return self.result_type_resolver_fn_

def get_type_promotion_path_acceptance_function(self):
"""Returns the acceptance function for this
elementwise binary function.

Acceptance function influences the type promotion
behavior of this unary function.
The function takes 4 arguments:
arg_dtype - Data type of the first argument
buf_dtype - Data type the argument would be cast to
res_dtype - Data type of the output array with function values
sycl_dev - The :class:`dpctl.SyclDevice` where the function
evaluation is carried out.
The function is invoked when the argument of the unary function
requires casting, e.g. the argument of `dpctl.tensor.log` is an
array with integral data type.
"""
return self.acceptance_fn_

@property
def types(self):
"""Returns information about types supported by
Expand Down Expand Up @@ -122,7 +163,10 @@ def __call__(self, x, out=None, order="K"):
if order not in ["C", "F", "K", "A"]:
order = "K"
buf_dt, res_dt = _find_buf_dtype(
x.dtype, self.result_type_resolver_fn_, x.sycl_device
x.dtype,
self.result_type_resolver_fn_,
x.sycl_device,
acceptance_fn=self.acceptance_fn_,
)
if res_dt is None:
raise TypeError(
Expand Down Expand Up @@ -482,7 +526,7 @@ def __init__(
if callable(acceptance_fn):
self.acceptance_fn_ = acceptance_fn
else:
self.acceptance_fn_ = _acceptance_fn_default
self.acceptance_fn_ = _acceptance_fn_default_binary

def __str__(self):
return f"<{self.__name__} '{self.name_}'>"
Expand Down
66 changes: 64 additions & 2 deletions dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import dpctl.tensor._tensor_elementwise_impl as ti

from ._elementwise_common import BinaryElementwiseFunc, UnaryElementwiseFunc
from ._type_utils import _acceptance_fn_divide
from ._type_utils import _acceptance_fn_divide, _acceptance_fn_reciprocal

# U01: ==== ABS (x)
_abs_docstring_ = """
Expand Down Expand Up @@ -1880,10 +1880,72 @@
Returns:
usm_narray:
An array containing the element-wise reciprocal square-root.
The data type of the returned array is determined by
The returned array has a floating-point data type determined by
the Type Promotion Rules.
"""

rsqrt = UnaryElementwiseFunc(
"rsqrt", ti._rsqrt_result_type, ti._rsqrt, _rsqrt_docstring_
)


# U42: ==== RECIPROCAL (x)
_reciprocal_docstring = """
reciprocal(x, out=None, order='K')

Computes the reciprocal of each element `x_i` for input array `x`.

Args:
x (usm_ndarray):
Input array, expected to have a real-valued floating-point data type.
out ({None, usm_ndarray}, optional):
Output array to populate.
Array have the correct shape and the expected data type.
order ("C","F","A","K", optional):
Memory layout of the newly output array, if parameter `out` is `None`.
Default: "K".
Returns:
usm_narray:
An array containing the element-wise reciprocals.
The returned array has a floating-point data type determined
by the Type Promotion Rules.
"""

reciprocal = UnaryElementwiseFunc(
"reciprocal",
ti._reciprocal_result_type,
ti._reciprocal,
_reciprocal_docstring,
acceptance_fn=_acceptance_fn_reciprocal,
)


# U43: ==== ANGLE (x)
_angle_docstring = """
angle(x, out=None, order='K')

Computes the phase angle (also called the argument) of each element `x_i` for
input array `x`.

Args:
x (usm_ndarray):
Input array, expected to have a complex-valued floating-point data type.
out ({None, usm_ndarray}, optional):
Output array to populate.
Array have the correct shape and the expected data type.
order ("C","F","A","K", optional):
Memory layout of the newly output array, if parameter `out` is `None`.
Default: "K".
Returns:
usm_narray:
An array containing the element-wise phase angles.
The returned array has a floating-point data type determined
by the Type Promotion Rules.
"""

angle = UnaryElementwiseFunc(
"angle",
ti._angle_result_type,
ti._angle,
_angle_docstring,
)
34 changes: 30 additions & 4 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,27 @@ def _to_device_supported_dtype(dt, dev):
return dt


def _find_buf_dtype(arg_dtype, query_fn, sycl_dev):
def _acceptance_fn_default_unary(arg_dtype, ret_buf_dt, res_dt, sycl_dev):
return True


def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev):
# if the kind of result is different from
# the kind of input, use the default data
# we use default dtype for the resulting kind.
# This guarantees alignment of reciprocal and
# divide output types.
if buf_dt.kind != arg_dtype.kind:
default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev)
if res_dt == default_dt:
return True
else:
return False
else:
return True


def _find_buf_dtype(arg_dtype, query_fn, sycl_dev, acceptance_fn):
res_dt = query_fn(arg_dtype)
if res_dt:
return None, res_dt
Expand All @@ -144,7 +164,11 @@ def _find_buf_dtype(arg_dtype, query_fn, sycl_dev):
if _can_cast(arg_dtype, buf_dt, _fp16, _fp64):
res_dt = query_fn(buf_dt)
if res_dt:
return buf_dt, res_dt
acceptable = acceptance_fn(arg_dtype, buf_dt, res_dt, sycl_dev)
if acceptable:
return buf_dt, res_dt
else:
continue

return None, None

Expand All @@ -163,7 +187,7 @@ def _get_device_default_dtype(dt_kind, sycl_dev):
raise RuntimeError


def _acceptance_fn_default(
def _acceptance_fn_default_binary(
arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev
):
return True
Expand Down Expand Up @@ -230,6 +254,8 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
"_find_buf_dtype",
"_find_buf_dtype2",
"_to_device_supported_dtype",
"_acceptance_fn_default",
"_acceptance_fn_default_unary",
"_acceptance_fn_reciprocal",
"_acceptance_fn_default_binary",
"_acceptance_fn_divide",
]
15 changes: 9 additions & 6 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,8 @@ cdef class usm_ndarray:
""" Returns real component for arrays with complex data-types
and returns itself for all other data-types.
"""
if (self.typenum_ < UAR_CFLOAT):
# explicitly check for UAR_HALF, which is greater than UAR_CFLOAT
if (self.typenum_ < UAR_CFLOAT or self.typenum_ == UAR_HALF):
# elements are real
return self
if (self.typenum_ < UAR_TYPE_SENTINEL):
Expand All @@ -698,7 +699,8 @@ cdef class usm_ndarray:
""" Returns imaginary component for arrays with complex data-types
and returns zero array for all other data-types.
"""
if (self.typenum_ < UAR_CFLOAT):
# explicitly check for UAR_HALF, which is greater than UAR_CFLOAT
if (self.typenum_ < UAR_CFLOAT or self.typenum_ == UAR_HALF):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

# elements are real
return _zero_like(self)
if (self.typenum_ < UAR_TYPE_SENTINEL):
Expand Down Expand Up @@ -1306,14 +1308,15 @@ cdef usm_ndarray _m_transpose(usm_ndarray ary):

cdef usm_ndarray _zero_like(usm_ndarray ary):
"""
Make C-contiguous array of zero elements with same shape
and type as ary.
Make C-contiguous array of zero elements with same shape,
type, device, and sycl_queue as ary.
"""
cdef dt = _make_typestr(ary.typenum_)
cdef usm_ndarray r = usm_ndarray(
_make_int_tuple(ary.nd_, ary.shape_),
_make_int_tuple(ary.nd_, ary.shape_) if ary.nd_ > 0 else tuple(),
dtype=dt,
buffer=ary.base_.get_usm_type()
buffer=ary.base_.get_usm_type(),
buffer_ctor_kwargs={"queue": ary.get_sycl_queue()},
)
r.base_.memset()
return r
Expand Down
Loading