|
69 | 69 |
|
70 | 70 |
|
71 | 71 | # TODO: implement a specific scalar-array kernel
|
72 |
def _call_multiply(a, b, out=None, outer_calc=False):
    """
    Adjusted multiply function for handling special cases of scalar-array dot
    products in linear algebra.

    `dpnp.multiply` cannot directly be used for calculating scalar-array dots,
    because the output dtype of multiply is not the same as the expected dtype
    for scalar-array dots. For example, if `sc` is a scalar and `a` is an
    array of type `float32`, then `dpnp.multiply(a, sc).dtype == dpnp.float32`
    (similar to NumPy). However, for scalar-array dots, such as the dot
    function, we need `dpnp.dot(a, sc).dtype == dpnp.float64` to align with
    NumPy. This function adjusts the behavior of `dpnp.multiply` to meet
    this requirement.

    Parameters
    ----------
    a, b : {dpnp array, scalar}
        Operands; exactly one of them is expected to be a Python scalar.
    out : {None, dpnp array}, optional
        Output array. Used directly only when its dtype already matches the
        array operand's dtype; otherwise the result is cast via
        `dpnp.get_result_array`.
    outer_calc : bool, optional
        When True, use `dpnp.multiply.outer` instead of `dpnp.multiply`
        (needed by scalar cases of `outer`). Default: False.

    Returns
    -------
    out : dpnp array
        The product, with dtype promoted as NumPy's dot would promote it.

    """

    sc, arr = (a, b) if dpnp.isscalar(a) else (b, a)
    # Map the Python scalar's type to a dtype supported on the array's
    # device, then promote against the array dtype — this reproduces the
    # dtype NumPy's dot would produce for a scalar-array product.
    sc_dtype = map_dtype_to_device(type(sc), arr.sycl_device)
    res_dtype = dpnp.result_type(sc_dtype, arr)
    multiply_func = dpnp.multiply.outer if outer_calc else dpnp.multiply
    if out is not None and out.dtype == arr.dtype:
        # Fast path: write straight into `out` when no dtype change is needed.
        res = multiply_func(a, b, out=out)
    else:
        # Compute at the promoted dtype; `get_result_array` below copies into
        # `out` (if given) with casting="no", raising on any unsafe cast.
        res = multiply_func(a, b, dtype=res_dtype)
    return dpnp.get_result_array(res, out, casting="no")
|
83 | 97 |
|
84 | 98 |
|
@@ -1109,16 +1123,15 @@ def outer(a, b, out=None):
|
1109 | 1123 |
|
1110 | 1124 | dpnp.check_supported_arrays_type(a, b, scalar_type=True, all_scalars=False)
|
1111 | 1125 | if dpnp.isscalar(a):
|
1112 |
| - x1 = a |
1113 | 1126 | x2 = dpnp.ravel(b)[None, :]
|
| 1127 | + result = _call_multiply(a, x2, out=out, outer_calc=True) |
1114 | 1128 | elif dpnp.isscalar(b):
|
1115 | 1129 | x1 = dpnp.ravel(a)[:, None]
|
1116 |
| - x2 = b |
| 1130 | + result = _call_multiply(x1, b, out=out, outer_calc=True) |
1117 | 1131 | else:
|
1118 |
| - x1 = dpnp.ravel(a) |
1119 |
| - x2 = dpnp.ravel(b) |
| 1132 | + result = dpnp.multiply.outer(dpnp.ravel(a), dpnp.ravel(b), out=out) |
1120 | 1133 |
|
1121 |
| - return dpnp.multiply.outer(x1, x2, out=out) |
| 1134 | + return result |
1122 | 1135 |
|
1123 | 1136 |
|
1124 | 1137 | def tensordot(a, b, axes=2):
|
@@ -1288,13 +1301,13 @@ def vdot(a, b):
|
1288 | 1301 | if b.size != 1:
|
1289 | 1302 | raise ValueError("The second array should be of size one.")
|
1290 | 1303 | a_conj = numpy.conj(a)
|
1291 |
| - return _call_multiply(a_conj, b) |
| 1304 | + return dpnp.squeeze(_call_multiply(a_conj, b)) |
1292 | 1305 |
|
1293 | 1306 | if dpnp.isscalar(b):
|
1294 | 1307 | if a.size != 1:
|
1295 | 1308 | raise ValueError("The first array should be of size one.")
|
1296 | 1309 | a_conj = dpnp.conj(a)
|
1297 |
| - return _call_multiply(a_conj, b) |
| 1310 | + return dpnp.squeeze(_call_multiply(a_conj, b)) |
1298 | 1311 |
|
1299 | 1312 | if a.ndim == 1 and b.ndim == 1:
|
1300 | 1313 | return dpnp_dot(a, b, out=None, conjugate=True)
|
|
0 commit comments