Skip to content

Fixed behavior of mathematical functions for scalars. #1233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 44 additions & 18 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,18 @@ def get(self):
return self.o_


class WeakInexactType:
"""Python type representing type of Python real- or
complex-valued floating point objects"""
class WeakFloatingType:
"""Python type representing type of Python floating point objects"""

def __init__(self, o):
self.o_ = o

def get(self):
return self.o_


class WeakComplexType:
"""Python type representing type of Python complex floating point objects"""

def __init__(self, o):
self.o_ = o
Expand All @@ -189,14 +198,17 @@ def _get_dtype(o, dev):
return WeakBooleanType(o)
if isinstance(o, int):
return WeakIntegralType(o)
if isinstance(o, (float, complex)):
return WeakInexactType(o)
if isinstance(o, float):
return WeakFloatingType(o)
if isinstance(o, complex):
return WeakComplexType(o)
return np.object_


def _validate_dtype(dt) -> bool:
return isinstance(
dt, (WeakBooleanType, WeakInexactType, WeakIntegralType)
dt,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
) or (
isinstance(dt, dpt.dtype)
and dt
Expand All @@ -220,22 +232,24 @@ def _validate_dtype(dt) -> bool:


def _weak_type_num_kind(o):
_map = {"?": 0, "i": 1, "f": 2}
_map = {"?": 0, "i": 1, "f": 2, "c": 3}
if isinstance(o, WeakBooleanType):
return _map["?"]
if isinstance(o, WeakIntegralType):
return _map["i"]
if isinstance(o, WeakInexactType):
if isinstance(o, WeakFloatingType):
return _map["f"]
if isinstance(o, WeakComplexType):
return _map["c"]
raise TypeError(
f"Unexpected type {o} while expecting "
"`WeakBooleanType`, `WeakIntegralType`, or "
"`WeakInexactType`."
"`WeakBooleanType`, `WeakIntegralType`,"
"`WeakFloatingType`, or `WeakComplexType`."
)


def _strong_dtype_num_kind(o):
_map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 2}
_map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3}
if not isinstance(o, dpt.dtype):
raise TypeError
k = o.kind
Expand All @@ -247,20 +261,29 @@ def _strong_dtype_num_kind(o):
def _resolve_weak_types(o1_dtype, o2_dtype, dev):
"Resolves weak data type per NEP-0050"
if isinstance(
o1_dtype, (WeakBooleanType, WeakInexactType, WeakIntegralType)
o1_dtype,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
):
if isinstance(
o2_dtype, (WeakBooleanType, WeakInexactType, WeakIntegralType)
o2_dtype,
(
WeakBooleanType,
WeakIntegralType,
WeakFloatingType,
WeakComplexType,
),
):
raise ValueError
o1_kind_num = _weak_type_num_kind(o1_dtype)
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
if o1_kind_num > o2_kind_num or o1_kind_num == 2:
if o1_kind_num > o2_kind_num:
if isinstance(o1_dtype, WeakBooleanType):
return dpt.bool, o2_dtype
if isinstance(o1_dtype, WeakIntegralType):
return dpt.int64, o2_dtype
if isinstance(o1_dtype.get(), complex):
if isinstance(o1_dtype, WeakComplexType):
if o2_dtype is dpt.float16 or o2_dtype is dpt.float32:
return dpt.complex64, o2_dtype
return (
_to_device_supported_dtype(dpt.complex128, dev),
o2_dtype,
Expand All @@ -269,16 +292,19 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
else:
return o2_dtype, o2_dtype
elif isinstance(
o2_dtype, (WeakBooleanType, WeakInexactType, WeakIntegralType)
o2_dtype,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
):
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
o2_kind_num = _weak_type_num_kind(o2_dtype)
if o2_kind_num > o1_kind_num or o2_kind_num == 2:
if o2_kind_num > o1_kind_num:
if isinstance(o2_dtype, WeakBooleanType):
return o1_dtype, dpt.bool
if isinstance(o2_dtype, WeakIntegralType):
return o1_dtype, dpt.int64
if isinstance(o2_dtype.get(), complex):
if isinstance(o2_dtype, WeakComplexType):
if o1_dtype is dpt.float16 or o1_dtype is dpt.float32:
return o1_dtype, dpt.complex64
return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
return (
o1_dtype,
Expand Down
21 changes: 13 additions & 8 deletions dpctl/tests/elementwise/test_multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,21 @@ def test_multiply_python_scalar(arr_dt):
assert isinstance(R, dpt.usm_ndarray)


def test_multiply_python_scalar_gh1219():
@pytest.mark.parametrize("arr_dt", _all_dtypes)
@pytest.mark.parametrize("sc", [bool(1), int(1), float(1), complex(1)])
def test_multiply_python_scalar_gh1219(arr_dt, sc):
q = get_queue_or_skip()
skip_if_dtype_not_supported(arr_dt, q)

X = dpt.ones(4, dtype="f4", sycl_queue=q)
Xnp = np.ones(4, dtype=arr_dt)

r = dpt.multiply(X, 2j)
expected = dpt.multiply(X, dpt.asarray(2j, sycl_queue=q))
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
X = dpt.ones(4, dtype=arr_dt, sycl_queue=q)

R = dpt.multiply(X, sc)
Rnp = np.multiply(Xnp, sc)
assert _compare_dtypes(R.dtype, Rnp.dtype, sycl_queue=q)

# symmetric case
r = dpt.multiply(2j, X)
expected = dpt.multiply(dpt.asarray(2j, sycl_queue=q), X)
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
R = dpt.multiply(sc, X)
Rnp = np.multiply(sc, Xnp)
assert _compare_dtypes(R.dtype, Rnp.dtype, sycl_queue=q)