Skip to content

Fixed dpctl.tensor.result_type function for scalars #1473

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions dpctl/tensor/_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,24 @@
_empty_like_triple_orderK,
)
from dpctl.tensor._elementwise_common import (
WeakBooleanType,
WeakComplexType,
WeakFloatingType,
WeakIntegralType,
_get_dtype,
_get_queue_usm_type,
_get_shape,
_strong_dtype_num_kind,
_validate_dtype,
_weak_type_num_kind,
)
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.tensor._type_utils import _can_cast, _to_device_supported_dtype
from dpctl.utils import ExecutionPlacementError

from ._type_utils import (
WeakBooleanType,
WeakComplexType,
WeakFloatingType,
WeakIntegralType,
_strong_dtype_num_kind,
_weak_type_num_kind,
)


def _resolve_one_strong_two_weak_types(st_dtype, dtype1, dtype2, dev):
"Resolves weak data types per NEP-0050,"
Expand Down
126 changes: 5 additions & 121 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,16 @@

from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
from ._type_utils import (
WeakBooleanType,
WeakComplexType,
WeakFloatingType,
WeakIntegralType,
_acceptance_fn_default_binary,
_acceptance_fn_default_unary,
_all_data_types,
_find_buf_dtype,
_find_buf_dtype2,
_resolve_weak_types,
_to_device_supported_dtype,
)

Expand Down Expand Up @@ -286,46 +291,6 @@ def _get_queue_usm_type(o):
return None, None


class WeakBooleanType:
"Python type representing type of Python boolean objects"

def __init__(self, o):
self.o_ = o

def get(self):
return self.o_


class WeakIntegralType:
"Python type representing type of Python integral objects"

def __init__(self, o):
self.o_ = o

def get(self):
return self.o_


class WeakFloatingType:
"""Python type representing type of Python floating point objects"""

def __init__(self, o):
self.o_ = o

def get(self):
return self.o_


class WeakComplexType:
"""Python type representing type of Python complex floating point objects"""

def __init__(self, o):
self.o_ = o

def get(self):
return self.o_


def _get_dtype(o, dev):
if isinstance(o, dpt.usm_ndarray):
return o.dtype
Expand Down Expand Up @@ -375,87 +340,6 @@ def _validate_dtype(dt) -> bool:
)


def _weak_type_num_kind(o):
_map = {"?": 0, "i": 1, "f": 2, "c": 3}
if isinstance(o, WeakBooleanType):
return _map["?"]
if isinstance(o, WeakIntegralType):
return _map["i"]
if isinstance(o, WeakFloatingType):
return _map["f"]
if isinstance(o, WeakComplexType):
return _map["c"]
raise TypeError(
f"Unexpected type {o} while expecting "
"`WeakBooleanType`, `WeakIntegralType`,"
"`WeakFloatingType`, or `WeakComplexType`."
)


def _strong_dtype_num_kind(o):
_map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3}
if not isinstance(o, dpt.dtype):
raise TypeError
k = o.kind
if k in _map:
return _map[k]
raise ValueError(f"Unrecognized kind {k} for dtype {o}")


def _resolve_weak_types(o1_dtype, o2_dtype, dev):
"Resolves weak data type per NEP-0050"
if isinstance(
o1_dtype,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
):
if isinstance(
o2_dtype,
(
WeakBooleanType,
WeakIntegralType,
WeakFloatingType,
WeakComplexType,
),
):
raise ValueError
o1_kind_num = _weak_type_num_kind(o1_dtype)
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
if o1_kind_num > o2_kind_num:
if isinstance(o1_dtype, WeakIntegralType):
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
if isinstance(o1_dtype, WeakComplexType):
if o2_dtype is dpt.float16 or o2_dtype is dpt.float32:
return dpt.complex64, o2_dtype
return (
_to_device_supported_dtype(dpt.complex128, dev),
o2_dtype,
)
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
else:
return o2_dtype, o2_dtype
elif isinstance(
o2_dtype,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
):
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
o2_kind_num = _weak_type_num_kind(o2_dtype)
if o2_kind_num > o1_kind_num:
if isinstance(o2_dtype, WeakIntegralType):
return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
if isinstance(o2_dtype, WeakComplexType):
if o1_dtype is dpt.float16 or o1_dtype is dpt.float32:
return o1_dtype, dpt.complex64
return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
return (
o1_dtype,
_to_device_supported_dtype(dpt.float64, dev),
)
else:
return o1_dtype, o1_dtype
else:
return o1_dtype, o2_dtype


def _get_shape(o):
if isinstance(o, dpt.usm_ndarray):
return o.shape
Expand Down
155 changes: 152 additions & 3 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,127 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
return None, None, None


class WeakBooleanType:
"Python type representing type of Python boolean objects"

def __init__(self, o):
self.o_ = o

def get(self):
return self.o_


class WeakIntegralType:
"Python type representing type of Python integral objects"

def __init__(self, o):
self.o_ = o

def get(self):
return self.o_


class WeakFloatingType:
"""Python type representing type of Python floating point objects"""

def __init__(self, o):
self.o_ = o

def get(self):
return self.o_


class WeakComplexType:
"""Python type representing type of Python complex floating point objects"""

def __init__(self, o):
self.o_ = o

def get(self):
return self.o_


def _weak_type_num_kind(o):
_map = {"?": 0, "i": 1, "f": 2, "c": 3}
if isinstance(o, WeakBooleanType):
return _map["?"]
if isinstance(o, WeakIntegralType):
return _map["i"]
if isinstance(o, WeakFloatingType):
return _map["f"]
if isinstance(o, WeakComplexType):
return _map["c"]
raise TypeError(
f"Unexpected type {o} while expecting "
"`WeakBooleanType`, `WeakIntegralType`,"
"`WeakFloatingType`, or `WeakComplexType`."
)


def _strong_dtype_num_kind(o):
_map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3}
if not isinstance(o, dpt.dtype):
raise TypeError
k = o.kind
if k in _map:
return _map[k]
raise ValueError(f"Unrecognized kind {k} for dtype {o}")


def _resolve_weak_types(o1_dtype, o2_dtype, dev):
"Resolves weak data type per NEP-0050"
if isinstance(
o1_dtype,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
):
if isinstance(
o2_dtype,
(
WeakBooleanType,
WeakIntegralType,
WeakFloatingType,
WeakComplexType,
),
):
raise ValueError
o1_kind_num = _weak_type_num_kind(o1_dtype)
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
if o1_kind_num > o2_kind_num:
if isinstance(o1_dtype, WeakIntegralType):
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
if isinstance(o1_dtype, WeakComplexType):
if o2_dtype is dpt.float16 or o2_dtype is dpt.float32:
return dpt.complex64, o2_dtype
return (
_to_device_supported_dtype(dpt.complex128, dev),
o2_dtype,
)
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
else:
return o2_dtype, o2_dtype
elif isinstance(
o2_dtype,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
):
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
o2_kind_num = _weak_type_num_kind(o2_dtype)
if o2_kind_num > o1_kind_num:
if isinstance(o2_dtype, WeakIntegralType):
return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
if isinstance(o2_dtype, WeakComplexType):
if o1_dtype is dpt.float16 or o1_dtype is dpt.float32:
return o1_dtype, dpt.complex64
return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
return (
o1_dtype,
_to_device_supported_dtype(dpt.float64, dev),
)
else:
return o1_dtype, o1_dtype
else:
return o1_dtype, o2_dtype


class finfo_object:
"""
`numpy.finfo` subclass which returns Python floating-point scalars for
Expand Down Expand Up @@ -407,17 +528,27 @@ def result_type(*arrays_and_dtypes):
"""
dtypes = []
devices = []
weak_dtypes = []
for arg_i in arrays_and_dtypes:
if isinstance(arg_i, dpt.usm_ndarray):
devices.append(arg_i.sycl_device)
dtypes.append(arg_i.dtype)
elif isinstance(arg_i, int):
weak_dtypes.append(WeakIntegralType(arg_i))
elif isinstance(arg_i, float):
weak_dtypes.append(WeakFloatingType(arg_i))
elif isinstance(arg_i, complex):
weak_dtypes.append(WeakComplexType(arg_i))
elif isinstance(arg_i, bool):
weak_dtypes.append(WeakBooleanType(arg_i))
else:
dt = dpt.dtype(arg_i)
_supported_dtype([dt])
dtypes.append(dt)

has_fp16 = True
has_fp64 = True
target_dev = None
if devices:
inspected = False
for d in devices:
Expand All @@ -435,17 +566,28 @@ def result_type(*arrays_and_dtypes):
else:
has_fp16 = d.has_aspect_fp16
has_fp64 = d.has_aspect_fp64
target_dev = d
inspected = True

if not (has_fp16 and has_fp64):
for dt in dtypes:
if not _dtype_supported_by_device_impl(dt, has_fp16, has_fp64):
raise ValueError(f"Argument {dt} is not supported by ")
raise ValueError(
f"Argument {dt} is not supported by the device"
)
res_dt = np.result_type(*dtypes)
res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64)
return res_dt
for wdt in weak_dtypes:
pair = _resolve_weak_types(wdt, res_dt, target_dev)
res_dt = np.result_type(*pair)
res_dt = _to_device_supported_dtype_impl(res_dt, has_fp16, has_fp64)
else:
res_dt = np.result_type(*dtypes)
if weak_dtypes:
weak_dt_obj = [wdt.get() for wdt in weak_dtypes]
res_dt = np.result_type(res_dt, *weak_dt_obj)

return np.result_type(*dtypes)
return res_dt


def iinfo(dtype):
Expand Down Expand Up @@ -528,8 +670,15 @@ def _supported_dtype(dtypes):
"_acceptance_fn_reciprocal",
"_acceptance_fn_default_binary",
"_acceptance_fn_divide",
"_resolve_weak_types",
"_weak_type_num_kind",
"_strong_dtype_num_kind",
"can_cast",
"finfo",
"iinfo",
"result_type",
"WeakBooleanType",
"WeakIntegralType",
"WeakFloatingType",
"WeakComplexType",
]
Loading