Skip to content

Commit b03beb3

Browse files
committed
Fixes incorrect result when multiplying real array and complex scalar
- Adds a test for the fix - Resolves #1219
1 parent 521867b commit b03beb3

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
255255
raise ValueError
256256
o1_kind_num = _weak_type_num_kind(o1_dtype)
257257
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
258-
if o1_kind_num > o2_kind_num:
258+
if o1_kind_num > o2_kind_num or o1_kind_num == 2:
259259
if isinstance(o1_dtype, WeakBooleanType):
260260
return dpt.bool, o2_dtype
261261
if isinstance(o1_dtype, WeakIntegralType):
@@ -273,7 +273,7 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
273273
):
274274
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
275275
o2_kind_num = _weak_type_num_kind(o2_dtype)
276-
if o2_kind_num > o1_kind_num:
276+
if o2_kind_num > o1_kind_num or o2_kind_num == 2:
277277
if isinstance(o2_dtype, WeakBooleanType):
278278
return o1_dtype, dpt.bool
279279
if isinstance(o2_dtype, WeakIntegralType):

dpctl/tests/elementwise/test_multiply.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,18 @@ def test_multiply_python_scalar(arr_dt):
152152
assert isinstance(R, dpt.usm_ndarray)
153153
R = dpt.multiply(sc, X)
154154
assert isinstance(R, dpt.usm_ndarray)
155+
156+
157+
def test_multiply_python_scalar_gh1219():
158+
q = get_queue_or_skip()
159+
160+
X = dpt.ones(4, dtype="f4", sycl_queue=q)
161+
162+
r = dpt.multiply(X, 2j)
163+
expected = dpt.multiply(X, dpt.asarray(2j, sycl_queue=q))
164+
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
165+
166+
# symmetric case
167+
r = dpt.multiply(2j, X)
168+
expected = dpt.multiply(dpt.asarray(2j, sycl_queue=q), X)
169+
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)

0 commit comments

Comments
 (0)