Skip to content

Commit d5ccdfd

Browse files
authored
dpnp.multiply() doesn't work properly with a scalar (#1254)
* dpnp.multiply() doesn't work properly with a scalar * Fix typo in description * Return scalar, if both inputs are scalars
1 parent 51938b0 commit d5ccdfd

File tree

11 files changed

+212
-142
lines changed

11 files changed

+212
-142
lines changed

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//*****************************************************************************
2-
// Copyright (c) 2016-2020, Intel Corporation
2+
// Copyright (c) 2016-2022, Intel Corporation
33
// All rights reserved.
44
//
55
// Redistribution and use in source and binary forms, with or without
@@ -183,6 +183,7 @@
183183
where, \
184184
dep_event_vec_ref); \
185185
DPCTLEvent_WaitAndThrow(event_ref); \
186+
DPCTLEvent_Delete(event_ref); \
186187
} \
187188
\
188189
template <typename _DataType_input, typename _DataType_output> \
@@ -690,6 +691,7 @@ static void func_map_init_elemwise_1arg_2type(func_map_t& fmap)
690691
where, \
691692
dep_event_vec_ref); \
692693
DPCTLEvent_WaitAndThrow(event_ref); \
694+
DPCTLEvent_Delete(event_ref); \
693695
} \
694696
\
695697
template <typename _DataType> \
@@ -1067,6 +1069,7 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
10671069
dep_event_vec_ref \
10681070
); \
10691071
DPCTLEvent_WaitAndThrow(event_ref); \
1072+
DPCTLEvent_Delete(event_ref); \
10701073
} \
10711074
\
10721075
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2> \
@@ -1732,36 +1735,56 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
17321735
eft_FLT, (void*)dpnp_multiply_c_ext<float, bool, float>};
17331736
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_BLN][eft_DBL] = {
17341737
eft_DBL, (void*)dpnp_multiply_c_ext<double, bool, double>};
1738+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_BLN][eft_C64] = {
1739+
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, bool, std::complex<float>>};
1740+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_BLN][eft_C128] = {
1741+
eft_C128, (void*)dpnp_multiply_c_ext<std::complex<double>, bool, std::complex<double>>};
1742+
17351743
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_INT][eft_BLN] = {
17361744
eft_INT, (void*)dpnp_multiply_c_ext<int32_t, int32_t, bool>};
17371745
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_INT][eft_INT] = {
17381746
eft_INT, (void*)dpnp_multiply_c_ext<int32_t, int32_t, int32_t>};
17391747
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_INT][eft_LNG] = {
17401748
eft_LNG, (void*)dpnp_multiply_c_ext<int64_t, int32_t, int64_t>};
17411749
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_INT][eft_FLT] = {
1742-
eft_DBL, (void*)dpnp_multiply_c_ext<double, int32_t, float>};
1750+
eft_FLT, (void*)dpnp_multiply_c_ext<float, int32_t, float>};
17431751
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_INT][eft_DBL] = {
17441752
eft_DBL, (void*)dpnp_multiply_c_ext<double, int32_t, double>};
1753+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_INT][eft_C64] = {
1754+
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, int32_t, std::complex<float>>};
1755+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_INT][eft_C128] = {
1756+
eft_C128, (void*)dpnp_multiply_c_ext<std::complex<double>, int32_t, std::complex<double>>};
1757+
17451758
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_LNG][eft_BLN] = {
17461759
eft_LNG, (void*)dpnp_multiply_c_ext<int64_t, int64_t, bool>};
17471760
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_LNG][eft_INT] = {
17481761
eft_LNG, (void*)dpnp_multiply_c_ext<int64_t, int64_t, int32_t>};
17491762
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_LNG][eft_LNG] = {
17501763
eft_LNG, (void*)dpnp_multiply_c_ext<int64_t, int64_t, int64_t>};
17511764
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_LNG][eft_FLT] = {
1752-
eft_DBL, (void*)dpnp_multiply_c_ext<double, int64_t, float>};
1765+
eft_FLT, (void*)dpnp_multiply_c_ext<float, int64_t, float>};
17531766
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_LNG][eft_DBL] = {
17541767
eft_DBL, (void*)dpnp_multiply_c_ext<double, int64_t, double>};
1768+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_LNG][eft_C64] = {
1769+
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, int64_t, std::complex<float>>};
1770+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_LNG][eft_C128] = {
1771+
eft_C128, (void*)dpnp_multiply_c_ext<std::complex<double>, int64_t, std::complex<double>>};
1772+
17551773
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_FLT][eft_BLN] = {
17561774
eft_FLT, (void*)dpnp_multiply_c_ext<float, float, bool>};
17571775
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_FLT][eft_INT] = {
1758-
eft_DBL, (void*)dpnp_multiply_c_ext<double, float, int32_t>};
1776+
eft_FLT, (void*)dpnp_multiply_c_ext<float, float, int32_t>};
17591777
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_FLT][eft_LNG] = {
1760-
eft_DBL, (void*)dpnp_multiply_c_ext<double, float, int64_t>};
1778+
eft_FLT, (void*)dpnp_multiply_c_ext<float, float, int64_t>};
17611779
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_FLT][eft_FLT] = {
17621780
eft_FLT, (void*)dpnp_multiply_c_ext<float, float, float>};
17631781
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_FLT][eft_DBL] = {
17641782
eft_DBL, (void*)dpnp_multiply_c_ext<double, float, double>};
1783+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_FLT][eft_C64] = {
1784+
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, float, std::complex<float>>};
1785+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_FLT][eft_C128] = {
1786+
eft_C128, (void*)dpnp_multiply_c_ext<std::complex<double>, float, std::complex<double>>};
1787+
17651788
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_DBL][eft_BLN] = {
17661789
eft_DBL, (void*)dpnp_multiply_c_ext<double, double, bool>};
17671790
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_DBL][eft_INT] = {
@@ -1772,6 +1795,10 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
17721795
eft_DBL, (void*)dpnp_multiply_c_ext<double, double, float>};
17731796
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_DBL][eft_DBL] = {
17741797
eft_DBL, (void*)dpnp_multiply_c_ext<double, double, double>};
1798+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_DBL][eft_C64] = {
1799+
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, double, std::complex<float>>};
1800+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_DBL][eft_C128] = {
1801+
eft_C128, (void*)dpnp_multiply_c_ext<std::complex<double>, double, std::complex<double>>};
17751802

17761803
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_C64][eft_BLN] = {
17771804
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, std::complex<float>, bool>};
@@ -1782,7 +1809,7 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
17821809
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_C64][eft_FLT] = {
17831810
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, std::complex<float>, float>};
17841811
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_C64][eft_DBL] = {
1785-
eft_C128, (void*)dpnp_multiply_c_ext<std::complex<double>, std::complex<float>, double>};
1812+
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, std::complex<float>, double>};
17861813
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_C64][eft_C64] = {
17871814
eft_C64, (void*)dpnp_multiply_c_ext<std::complex<float>, std::complex<float>, std::complex<float>>};
17881815
fmap[DPNPFuncName::DPNP_FN_MULTIPLY_EXT][eft_C64][eft_C128] = {

dpnp/dpnp_array.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,7 @@ def asnumpy(self):
437437
438438
Returns
439439
-------
440+
numpy.ndarray
440441
An instance of :class:`numpy.ndarray` populated with the array content.
441442
442443
"""

dpnp/dpnp_iface.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@
6868
"get_normalized_queue_device"
6969
]
7070

71+
from dpnp import (
72+
isscalar
73+
)
74+
7175
from dpnp.dpnp_iface_arraycreation import *
7276
from dpnp.dpnp_iface_bitwise import *
7377
from dpnp.dpnp_iface_counting import *
@@ -187,7 +191,10 @@ def convert_single_elem_array_to_scalar(obj, keepdims=False):
187191
return obj
188192

189193

190-
def get_dpnp_descriptor(ext_obj, copy_when_strides=True, copy_when_nondefault_queue=True):
194+
def get_dpnp_descriptor(ext_obj,
195+
copy_when_strides=True,
196+
copy_when_nondefault_queue=True,
197+
alloc_queue=None):
191198
"""
192199
Return True:
193200
never
@@ -206,6 +213,11 @@ def get_dpnp_descriptor(ext_obj, copy_when_strides=True, copy_when_nondefault_qu
206213
if use_origin_backend():
207214
return False
208215

216+
# If input object is a scalar, it means it was allocated on host memory.
217+
# We need to copy it to device memory according to compute follows data paradigm.
218+
if isscalar(ext_obj):
219+
ext_obj = array(ext_obj, sycl_queue=alloc_queue)
220+
209221
# while dpnp functions have no implementation with strides support
210222
# we need to create a non-strided copy
211223
# if function get implementation for strides case
@@ -226,13 +238,12 @@ def get_dpnp_descriptor(ext_obj, copy_when_strides=True, copy_when_nondefault_qu
226238
# we need to create a copy on device associated with DPNP_QUEUE
227239
# if function get implementation for different queue
228240
# then this behavior can be disabled with setting "copy_when_nondefault_queue"
229-
arr_obj = unwrap_array(ext_obj)
230-
queue = getattr(arr_obj, "sycl_queue", None)
241+
queue = getattr(ext_obj, "sycl_queue", None)
231242
if queue is not None and copy_when_nondefault_queue:
232243
default_queue = dpctl.SyclQueue()
233244
queue_is_default = dpctl.utils.get_execution_queue([queue, default_queue]) is not None
234245
if not queue_is_default:
235-
ext_obj = array(arr_obj, sycl_queue=default_queue)
246+
ext_obj = array(ext_obj, sycl_queue=default_queue)
236247

237248
dpnp_desc = dpnp_descriptor(ext_obj)
238249
if dpnp_desc.is_valid:

dpnp/dpnp_iface_manipulation.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# distutils: language = c++
33
# -*- coding: utf-8 -*-
44
# *****************************************************************************
5-
# Copyright (c) 2016-2020, Intel Corporation
5+
# Copyright (c) 2016-2022, Intel Corporation
66
# All rights reserved.
77
#
88
# Redistribution and use in source and binary forms, with or without
@@ -131,12 +131,13 @@ def atleast_2d(*arys):
131131
all_is_array = True
132132
arys_desc = []
133133
for ary in arys:
134-
ary_desc = dpnp.get_dpnp_descriptor(ary, copy_when_nondefault_queue=False)
135-
if ary_desc:
136-
arys_desc.append(ary_desc)
137-
else:
138-
all_is_array = False
139-
break
134+
if not dpnp.isscalar(ary):
135+
ary_desc = dpnp.get_dpnp_descriptor(ary, copy_when_nondefault_queue=False)
136+
if ary_desc:
137+
arys_desc.append(ary_desc)
138+
continue
139+
all_is_array = False
140+
break
140141

141142
if not use_origin_backend(arys[0]) and all_is_array:
142143
result = []
@@ -166,12 +167,13 @@ def atleast_3d(*arys):
166167
all_is_array = True
167168
arys_desc = []
168169
for ary in arys:
169-
ary_desc = dpnp.get_dpnp_descriptor(ary, copy_when_nondefault_queue=False)
170-
if ary_desc:
171-
arys_desc.append(ary_desc)
172-
else:
173-
all_is_array = False
174-
break
170+
if not dpnp.isscalar(ary):
171+
ary_desc = dpnp.get_dpnp_descriptor(ary, copy_when_nondefault_queue=False)
172+
if ary_desc:
173+
arys_desc.append(ary_desc)
174+
continue
175+
all_is_array = False
176+
break
175177

176178
if not use_origin_backend(arys[0]) and all_is_array:
177179
result = []

dpnp/dpnp_iface_mathematical.py

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# distutils: language = c++
33
# -*- coding: utf-8 -*-
44
# *****************************************************************************
5-
# Copyright (c) 2016-2020, Intel Corporation
5+
# Copyright (c) 2016-2022, Intel Corporation
66
# All rights reserved.
77
#
88
# Redistribution and use in source and binary forms, with or without
@@ -850,9 +850,9 @@ def fmod(x1, x2, dtype=None, out=None, where=True, **kwargs):
850850
pass
851851
elif x1_is_scalar and x2_is_scalar:
852852
pass
853-
elif x1_desc and x1.ndim == 0:
853+
elif x1_desc and x1_desc.ndim == 0:
854854
pass
855-
elif x2_desc and x2.ndim == 0:
855+
elif x2_desc and x2_desc.ndim == 0:
856856
pass
857857
elif dtype is not None:
858858
pass
@@ -1075,51 +1075,62 @@ def modf(x1, **kwargs):
10751075
return call_origin(numpy.modf, x1, **kwargs)
10761076

10771077

1078-
def multiply(x1, x2, dtype=None, out=None, where=True, **kwargs):
1078+
def multiply(x1,
1079+
x2,
1080+
/,
1081+
out=None,
1082+
*,
1083+
where=True,
1084+
dtype=None,
1085+
subok=True,
1086+
**kwargs):
10791087
"""
10801088
Multiply arguments element-wise.
10811089
10821090
For full documentation refer to :obj:`numpy.multiply`.
10831091
1092+
Returns
1093+
-------
1094+
y : {dpnp.ndarray, scalar}
1095+
The product of `x1` and `x2`, element-wise.
1096+
The result is a scalar if both x1 and x2 are scalars.
1097+
10841098
Limitations
10851099
-----------
1086-
Parameters ``x1`` and ``x2`` are supported as either :obj:`dpnp.ndarray` or scalar.
1087-
Parameters ``dtype``, ``out`` and ``where`` are supported with their default values.
1100+
Parameters `x1` and `x2` are supported as either :class:`dpnp.ndarray` or scalar.
1101+
Parameters `out`, `where`, `dtype` and `subok` are supported with their default values.
10881102
Keyword arguments ``kwargs`` are currently unsupported.
10891103
Otherwise the functions will be executed sequentially on CPU.
10901104
Input array data types are limited by supported DPNP :ref:`Data types`.
10911105
10921106
Examples
10931107
--------
1094-
>>> import dpnp as np
1095-
>>> a = np.array([1, 2, 3, 4, 5])
1096-
>>> result = np.multiply(a, a)
1097-
>>> [x for x in result]
1108+
>>> import dpnp as dp
1109+
>>> a = dp.array([1, 2, 3, 4, 5])
1110+
>>> result = dp.multiply(a, a)
1111+
>>> print(result)
10981112
[1, 4, 9, 16, 25]
10991113
11001114
"""
11011115

1102-
x1_is_scalar = dpnp.isscalar(x1)
1103-
x2_is_scalar = dpnp.isscalar(x2)
1104-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False)
1105-
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False)
1106-
1107-
if x1_desc and x2_desc and not kwargs:
1108-
if not x2_desc and not x2_is_scalar:
1109-
pass
1110-
elif x1_is_scalar and x2_is_scalar:
1111-
pass
1112-
elif x1_desc and x1_desc.ndim == 0:
1113-
pass
1114-
elif x2_desc and x2_desc.ndim == 0:
1115-
pass
1116-
elif dtype is not None:
1117-
pass
1118-
elif out is not None:
1119-
pass
1120-
elif not where:
1121-
pass
1122-
else:
1116+
if out is not None:
1117+
pass
1118+
elif where is not True:
1119+
pass
1120+
elif dtype is not None:
1121+
pass
1122+
elif subok is not True:
1123+
pass
1124+
elif dpnp.isscalar(x1) and dpnp.isscalar(x2):
1125+
# keep the result in host memory, if both inputs are scalars
1126+
return x1 * x2
1127+
else:
1128+
# get a common queue to copy data from the host into a device if any input is scalar
1129+
queue = get_common_allocation_queue([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else None
1130+
1131+
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False, alloc_queue=queue)
1132+
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False, alloc_queue=queue)
1133+
if x1_desc and x2_desc:
11231134
return dpnp_multiply(x1_desc, x2_desc, dtype=dtype, out=out, where=where).get_pyobj()
11241135

11251136
return call_origin(numpy.multiply, x1, x2, dtype=dtype, out=out, where=where, **kwargs)

0 commit comments

Comments
 (0)