Skip to content

Device-aware can_cast and result_type #1488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,12 @@
from dpctl.tensor._manipulation_functions import (
broadcast_arrays,
broadcast_to,
can_cast,
concat,
expand_dims,
finfo,
flip,
iinfo,
moveaxis,
permute_dims,
repeat,
result_type,
roll,
squeeze,
stack,
Expand Down Expand Up @@ -180,6 +176,7 @@
sum,
)
from ._testing import allclose
from ._type_utils import can_cast, finfo, iinfo, result_type

__all__ = [
"Device",
Expand Down
1 change: 1 addition & 0 deletions dpctl/tensor/_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def _get_dtype(inp_dt, sycl_obj, ref_type=None):

__all__ = [
"dtype",
"_get_dtype",
"isdtype",
"bool",
"int8",
Expand Down
217 changes: 1 addition & 216 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,101 +27,14 @@
import dpctl.utils as dputils

from ._copy_utils import _broadcast_strides
from ._type_utils import _to_device_supported_dtype
from ._type_utils import _supported_dtype, _to_device_supported_dtype

__doc__ = (
"Implementation module for array manipulation "
"functions in :module:`dpctl.tensor`"
)


class finfo_object:
"""
`numpy.finfo` subclass which returns Python floating-point scalars for
`eps`, `max`, `min`, and `smallest_normal` attributes.
"""

def __init__(self, dtype):
_supported_dtype([dpt.dtype(dtype)])
self._finfo = np.finfo(dtype)

@property
def bits(self):
"""
number of bits occupied by the real-valued floating-point data type.
"""
return int(self._finfo.bits)

@property
def smallest_normal(self):
"""
smallest positive real-valued floating-point number with full
precision.
"""
return float(self._finfo.smallest_normal)

@property
def tiny(self):
"""an alias for `smallest_normal`"""
return float(self._finfo.tiny)

@property
def eps(self):
"""
difference between 1.0 and the next smallest representable real-valued
floating-point number larger than 1.0 according to the IEEE-754
standard.
"""
return float(self._finfo.eps)

@property
def epsneg(self):
"""
difference between 1.0 and the next smallest representable real-valued
floating-point number smaller than 1.0 according to the IEEE-754
standard.
"""
return float(self._finfo.epsneg)

@property
def min(self):
"""smallest representable real-valued number."""
return float(self._finfo.min)

@property
def max(self):
"largest representable real-valued number."
return float(self._finfo.max)

@property
def resolution(self):
"the approximate decimal resolution of this type."
return float(self._finfo.resolution)

@property
def precision(self):
"""
the approximate number of decimal digits to which this kind of
floating point type is precise.
"""
return float(self._finfo.precision)

@property
def dtype(self):
"""
the dtype for which finfo returns information. For complex input, the
returned dtype is the associated floating point dtype for its real and
complex components.
"""
return self._finfo.dtype

def __str__(self):
return self._finfo.__str__()

def __repr__(self):
return self._finfo.__repr__()


def _broadcast_shape_impl(shapes):
if len(set(shapes)) == 1:
return shapes[0]
Expand Down Expand Up @@ -681,127 +594,6 @@ def stack(arrays, axis=0):
return res


def can_cast(from_, to, casting="safe"):
""" can_cast(from, to, casting="safe")

Determines if one data type can be cast to another data type according \
to Type Promotion Rules.

Args:
from (usm_ndarray, dtype): source data type
to (dtype): target data type
casting ({'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional):
controls what kind of data casting may occur.

Returns:
bool:
Gives `True` if cast can occur according to the casting rule.
"""
if isinstance(to, dpt.usm_ndarray):
raise TypeError("Expected dtype type.")

dtype_to = dpt.dtype(to)

dtype_from = (
from_.dtype if isinstance(from_, dpt.usm_ndarray) else dpt.dtype(from_)
)

_supported_dtype([dtype_from, dtype_to])

return np.can_cast(dtype_from, dtype_to, casting)


def result_type(*arrays_and_dtypes):
"""
result_type(arrays_and_dtypes)

Returns the dtype that results from applying the Type Promotion Rules to \
the arguments.

Args:
arrays_and_dtypes (object):
An arbitrary length sequence of arrays or dtypes.

Returns:
dtype:
The dtype resulting from an operation involving the
input arrays and dtypes.
"""
dtypes = [
X.dtype if isinstance(X, dpt.usm_ndarray) else dpt.dtype(X)
for X in arrays_and_dtypes
]

_supported_dtype(dtypes)

return np.result_type(*dtypes)


def iinfo(dtype):
"""iinfo(dtype)

Returns machine limits for integer data types.

Args:
dtype (dtype, usm_ndarray):
integer dtype or
an array with integer dtype.

Returns:
iinfo_object:
An object with the following attributes
* bits: int
number of bits occupied by the data type
* max: int
largest representable number.
* min: int
smallest representable number.
* dtype: dtype
integer data type.
"""
if isinstance(dtype, dpt.usm_ndarray):
dtype = dtype.dtype
_supported_dtype([dpt.dtype(dtype)])
return np.iinfo(dtype)


def finfo(dtype):
"""finfo(type)

Returns machine limits for floating-point data types.

Args:
dtype (dtype, usm_ndarray): floating-point dtype or
an array with floating point data type.
If complex, the information is about its component
data type.

Returns:
finfo_object:
an object have the following attributes
* bits: int
number of bits occupied by dtype.
* eps: float
difference between 1.0 and the next smallest representable
real-valued floating-point number larger than 1.0 according
to the IEEE-754 standard.
* max: float
largest representable real-valued number.
* min: float
smallest representable real-valued number.
* smallest_normal: float
smallest positive real-valued floating-point number with
full precision.
* dtype: dtype
real-valued floating-point data type.

"""
if isinstance(dtype, dpt.usm_ndarray):
dtype = dtype.dtype
_supported_dtype([dpt.dtype(dtype)])
return finfo_object(dtype)


def unstack(X, axis=0):
"""unstack(x, axis=0)

Expand Down Expand Up @@ -1229,10 +1021,3 @@ def tile(x, repetitions):
)
hev.wait()
return dpt.reshape(res, res_shape)


def _supported_dtype(dtypes):
for dtype in dtypes:
if dtype.char not in "?bBhHiIlLqQefdFD":
raise ValueError(f"Dpctl doesn't support dtype {dtype}.")
return True
Loading