|
34 | 34 | __all__ = ["dpnp_dot", "dpnp_matmul"] |
35 | 35 |
|
36 | 36 |
|
37 | | -def _op_res_dtype(*arrays, dtype, casting, sycl_queue): |
| 37 | +def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None): |
38 | 38 | """ |
39 | | - _op_res_dtype(*arrays, dtype, casting, sycl_queue) |
40 | | -
|
41 | | - Determines the output array data type and an intermediate data type |
42 | | - used in performing calculations related to a specific math function. |
43 | | - If dtype is ``None``, the output array data type of the operation is |
44 | | - determined based on the Promotion Type Rule and device capabilities. |
45 | | - Otherwise, `dtype` is used as output array dtype, if input arrays |
46 | | - can cast to it according to the casting rule determined. If casting |
47 | | - cannot be done, a ``TypeError`` is raised. |
48 | | - The intermediate data type is the data type used for performing the math |
49 | | - function calculations. If output array dtype is a floating-point data type, |
50 | | - it is also used for the intermediate data type. If output array dtype is an |
51 | | - integral data type, the default floating point data type of the device where |
52 | | - input arrays are allocated on are used for intermediate data type. |
53 | | -
|
54 | | - Parameters |
55 | | - ---------- |
56 | | - arrays : {dpnp.ndarray, usm_ndarray} |
57 | | - Input arrays. |
58 | | - dtype : dtype |
59 | | - If not ``None``, data type of the output array. |
60 | | - casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional |
61 | | - Controls what kind of data casting may occur. |
62 | | - sycl_queue : {SyclQueue} |
63 | | - A SYCL queue to use for determining default floating point datat type. |
| 39 | + Creating a copy of input array if needed. |
64 | 40 |
|
65 | | - Returns |
66 | | - ------- |
67 | | - op_dtype, res_dtype : |
68 | | - `op_dtype` is the data type used in performing math function calculations. |
69 | | - The input arrays of the math function are cast to `op_dtype` and then |
70 | | - the calculations are performed. |
71 | | - `res_dtype` is the output data type. When the result is obtained, it is cast |
72 | | - to `res_dtype`. |
| 41 | + If `contig_copy` is ``True``, a C-contiguous copy of input array is returned. |
| 42 | + In this case, the copy array has the input array data type unless `dtype` is |
| 43 | + determined. |
| 44 | + If `contig_copy` is ``False`` and input array data type is different than `dtype`, |
| 45 | + a C-contiguous copy of input array with specified `dtype` is returned. |
73 | 46 |
|
74 | 47 | """ |
75 | 48 |
|
76 | | - res_dtype = dpnp.result_type(*arrays) |
77 | | - default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue) |
78 | | - |
79 | | - if dtype is not None: |
80 | | - if dpnp.can_cast(res_dtype, dtype, casting=casting): |
81 | | - res_dtype = dtype |
82 | | - else: |
83 | | - raise TypeError( |
84 | | - f"Cannot cast ufunc 'matmul' output from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}" |
85 | | - ) |
86 | | - |
87 | | - op_dtype = ( |
88 | | - res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype |
89 | | - ) |
| 49 | + if contig_copy: |
| 50 | + copy = contig_copy |
| 51 | + else: |
| 52 | + copy = x.dtype != dtype if dtype is not None else False |
90 | 53 |
|
91 | | - return op_dtype, res_dtype |
| 54 | + if copy: |
| 55 | + x_copy = dpnp.empty_like(x, dtype=dtype, order="C") |
| 56 | + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( |
| 57 | + src=dpnp.get_usm_ndarray(x), |
| 58 | + dst=x_copy.get_array(), |
| 59 | + sycl_queue=x.sycl_queue, |
| 60 | + ) |
| 61 | + dep_events.append(copy_ev) |
| 62 | + host_events.append(ht_copy_ev) |
| 63 | + return x_copy |
| 64 | + return x |
92 | 65 |
|
93 | 66 |
|
94 | 67 | def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list): |
@@ -153,34 +126,61 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list): |
153 | 126 | return ht_blas_ev, ht_tasks_list, res |
154 | 127 |
|
155 | 128 |
|
156 | | -def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None): |
| 129 | +def _op_res_dtype(*arrays, dtype, casting, sycl_queue): |
157 | 130 | """ |
158 | | - Creating a copy of input array if needed. |
| 131 | + _op_res_dtype(*arrays, dtype, casting, sycl_queue) |
159 | 132 |
|
160 | | - If `contig_copy` is ``True``, a C-contiguous copy of input array is returned. |
161 | | - In this case, the copy array has the input array data type unless `dtype` is |
162 | | - determined. |
163 | | - If `contig_copy` is ``False`` and input array data type is different than `dtype`, |
164 | | - a C-contiguous copy of input array with specified `dtype` is returned. |
| 133 | + Determines the output array data type and an intermediate data type |
| 134 | + used in performing calculations related to a specific math function. |
| 135 | + If dtype is ``None``, the output array data type of the operation is |
| 136 | + determined based on the Promotion Type Rule and device capabilities. |
| 137 | + Otherwise, `dtype` is used as output array dtype, if input arrays |
| 138 | + can cast to it according to the casting rule determined. If casting |
| 139 | + cannot be done, a ``TypeError`` is raised. |
| 140 | + The intermediate data type is the data type used for performing the math |
| 141 | + function calculations. If output array dtype is a floating-point data type, |
| 142 | + it is also used for the intermediate data type. If output array dtype is an |
| 143 | + integral data type, the default floating point data type of the device where |
| 144 | + input arrays are allocated on are used for intermediate data type. |
| 145 | +
|
| 146 | + Parameters |
| 147 | + ---------- |
| 148 | + arrays : {dpnp.ndarray, usm_ndarray} |
| 149 | + Input arrays. |
| 150 | + dtype : dtype |
| 151 | + If not ``None``, data type of the output array. |
| 152 | + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional |
| 153 | + Controls what kind of data casting may occur. |
| 154 | + sycl_queue : {SyclQueue} |
| 155 | + A SYCL queue to use for determining default floating point datat type. |
| 156 | +
|
| 157 | + Returns |
| 158 | + ------- |
| 159 | + op_dtype, res_dtype : |
| 160 | + `op_dtype` is the data type used in performing math function calculations. |
| 161 | + The input arrays of the math function are cast to `op_dtype` and then |
| 162 | + the calculations are performed. |
| 163 | + `res_dtype` is the output data type. When the result is obtained, it is cast |
| 164 | + to `res_dtype`. |
165 | 165 |
|
166 | 166 | """ |
167 | 167 |
|
168 | | - if contig_copy: |
169 | | - copy = contig_copy |
170 | | - else: |
171 | | - copy = x.dtype != dtype if dtype is not None else False |
| 168 | + res_dtype = dpnp.result_type(*arrays) |
| 169 | + default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue) |
172 | 170 |
|
173 | | - if copy: |
174 | | - x_copy = dpnp.empty_like(x, dtype=dtype, order="C") |
175 | | - ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( |
176 | | - src=dpnp.get_usm_ndarray(x), |
177 | | - dst=x_copy.get_array(), |
178 | | - sycl_queue=x.sycl_queue, |
179 | | - ) |
180 | | - dep_events.append(copy_ev) |
181 | | - host_events.append(ht_copy_ev) |
182 | | - return x_copy |
183 | | - return x |
| 171 | + if dtype is not None: |
| 172 | + if dpnp.can_cast(res_dtype, dtype, casting=casting): |
| 173 | + res_dtype = dtype |
| 174 | + else: |
| 175 | + raise TypeError( |
| 176 | + f"Cannot cast ufunc 'matmul' output from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}" |
| 177 | + ) |
| 178 | + |
| 179 | + op_dtype = ( |
| 180 | + res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype |
| 181 | + ) |
| 182 | + |
| 183 | + return op_dtype, res_dtype |
184 | 184 |
|
185 | 185 |
|
186 | 186 | def dpnp_dot( |
@@ -394,6 +394,11 @@ def dpnp_matmul( |
394 | 394 | dtype=gemm_dtype, |
395 | 395 | ) |
396 | 396 |
|
| 397 | + # TODO: investigate usage of gemv (gemv_batch) function |
| 398 | + # from BLAS when one of the inputs is a vector to |
| 399 | + # gain performance. |
| 400 | + # TODO: investigate usage of syrk function from BLAS in |
| 401 | + # case of a.T @ a and a @ a.T to gain performance. |
397 | 402 | if x1_is_2D and x2_is_2D: |
398 | 403 | ht_blas_ev, _ = bi._gemm( |
399 | 404 | exec_q, |
|
0 commit comments