Skip to content

Commit 6bb09e7

Browse files
authored
Merge pull request #1233 from IntelPython/fix_multiply_scalar_dtype
Fixed behavior of mathematical functions for floating-point and complex floating-point scalars.
2 parents 368a17e + 4a09766 commit 6bb09e7

File tree

2 files changed

+57
-26
lines changed

2 files changed

+57
-26
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,18 @@ def get(self):
162162
return self.o_
163163

164164

165-
class WeakInexactType:
166-
"""Python type representing type of Python real- or
167-
complex-valued floating point objects"""
165+
class WeakFloatingType:
166+
"""Python type representing type of Python floating point objects"""
167+
168+
def __init__(self, o):
169+
self.o_ = o
170+
171+
def get(self):
172+
return self.o_
173+
174+
175+
class WeakComplexType:
176+
"""Python type representing type of Python complex floating point objects"""
168177

169178
def __init__(self, o):
170179
self.o_ = o
@@ -189,14 +198,17 @@ def _get_dtype(o, dev):
189198
return WeakBooleanType(o)
190199
if isinstance(o, int):
191200
return WeakIntegralType(o)
192-
if isinstance(o, (float, complex)):
193-
return WeakInexactType(o)
201+
if isinstance(o, float):
202+
return WeakFloatingType(o)
203+
if isinstance(o, complex):
204+
return WeakComplexType(o)
194205
return np.object_
195206

196207

197208
def _validate_dtype(dt) -> bool:
198209
return isinstance(
199-
dt, (WeakBooleanType, WeakInexactType, WeakIntegralType)
210+
dt,
211+
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
200212
) or (
201213
isinstance(dt, dpt.dtype)
202214
and dt
@@ -220,22 +232,24 @@ def _validate_dtype(dt) -> bool:
220232

221233

222234
def _weak_type_num_kind(o):
223-
_map = {"?": 0, "i": 1, "f": 2}
235+
_map = {"?": 0, "i": 1, "f": 2, "c": 3}
224236
if isinstance(o, WeakBooleanType):
225237
return _map["?"]
226238
if isinstance(o, WeakIntegralType):
227239
return _map["i"]
228-
if isinstance(o, WeakInexactType):
240+
if isinstance(o, WeakFloatingType):
229241
return _map["f"]
242+
if isinstance(o, WeakComplexType):
243+
return _map["c"]
230244
raise TypeError(
231245
f"Unexpected type {o} while expecting "
232-
"`WeakBooleanType`, `WeakIntegralType`, or "
233-
"`WeakInexactType`."
246+
"`WeakBooleanType`, `WeakIntegralType`,"
247+
"`WeakFloatingType`, or `WeakComplexType`."
234248
)
235249

236250

237251
def _strong_dtype_num_kind(o):
238-
_map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 2}
252+
_map = {"b": 0, "i": 1, "u": 1, "f": 2, "c": 3}
239253
if not isinstance(o, dpt.dtype):
240254
raise TypeError
241255
k = o.kind
@@ -247,20 +261,29 @@ def _strong_dtype_num_kind(o):
247261
def _resolve_weak_types(o1_dtype, o2_dtype, dev):
248262
"Resolves weak data type per NEP-0050"
249263
if isinstance(
250-
o1_dtype, (WeakBooleanType, WeakInexactType, WeakIntegralType)
264+
o1_dtype,
265+
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
251266
):
252267
if isinstance(
253-
o2_dtype, (WeakBooleanType, WeakInexactType, WeakIntegralType)
268+
o2_dtype,
269+
(
270+
WeakBooleanType,
271+
WeakIntegralType,
272+
WeakFloatingType,
273+
WeakComplexType,
274+
),
254275
):
255276
raise ValueError
256277
o1_kind_num = _weak_type_num_kind(o1_dtype)
257278
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
258-
if o1_kind_num > o2_kind_num or o1_kind_num == 2:
279+
if o1_kind_num > o2_kind_num:
259280
if isinstance(o1_dtype, WeakBooleanType):
260281
return dpt.bool, o2_dtype
261282
if isinstance(o1_dtype, WeakIntegralType):
262283
return dpt.int64, o2_dtype
263-
if isinstance(o1_dtype.get(), complex):
284+
if isinstance(o1_dtype, WeakComplexType):
285+
if o2_dtype is dpt.float16 or o2_dtype is dpt.float32:
286+
return dpt.complex64, o2_dtype
264287
return (
265288
_to_device_supported_dtype(dpt.complex128, dev),
266289
o2_dtype,
@@ -269,16 +292,19 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
269292
else:
270293
return o2_dtype, o2_dtype
271294
elif isinstance(
272-
o2_dtype, (WeakBooleanType, WeakInexactType, WeakIntegralType)
295+
o2_dtype,
296+
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
273297
):
274298
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
275299
o2_kind_num = _weak_type_num_kind(o2_dtype)
276-
if o2_kind_num > o1_kind_num or o2_kind_num == 2:
300+
if o2_kind_num > o1_kind_num:
277301
if isinstance(o2_dtype, WeakBooleanType):
278302
return o1_dtype, dpt.bool
279303
if isinstance(o2_dtype, WeakIntegralType):
280304
return o1_dtype, dpt.int64
281-
if isinstance(o2_dtype.get(), complex):
305+
if isinstance(o2_dtype, WeakComplexType):
306+
if o1_dtype is dpt.float16 or o1_dtype is dpt.float32:
307+
return o1_dtype, dpt.complex64
282308
return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
283309
return (
284310
o1_dtype,

dpctl/tests/elementwise/test_multiply.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -154,16 +154,21 @@ def test_multiply_python_scalar(arr_dt):
154154
assert isinstance(R, dpt.usm_ndarray)
155155

156156

157-
def test_multiply_python_scalar_gh1219():
157+
@pytest.mark.parametrize("arr_dt", _all_dtypes)
158+
@pytest.mark.parametrize("sc", [bool(1), int(1), float(1), complex(1)])
159+
def test_multiply_python_scalar_gh1219(arr_dt, sc):
158160
q = get_queue_or_skip()
161+
skip_if_dtype_not_supported(arr_dt, q)
159162

160-
X = dpt.ones(4, dtype="f4", sycl_queue=q)
163+
Xnp = np.ones(4, dtype=arr_dt)
161164

162-
r = dpt.multiply(X, 2j)
163-
expected = dpt.multiply(X, dpt.asarray(2j, sycl_queue=q))
164-
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
165+
X = dpt.ones(4, dtype=arr_dt, sycl_queue=q)
166+
167+
R = dpt.multiply(X, sc)
168+
Rnp = np.multiply(Xnp, sc)
169+
assert _compare_dtypes(R.dtype, Rnp.dtype, sycl_queue=q)
165170

166171
# symmetric case
167-
r = dpt.multiply(2j, X)
168-
expected = dpt.multiply(dpt.asarray(2j, sycl_queue=q), X)
169-
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
172+
R = dpt.multiply(sc, X)
173+
Rnp = np.multiply(sc, Xnp)
174+
assert _compare_dtypes(R.dtype, Rnp.dtype, sycl_queue=q)

0 commit comments

Comments
 (0)