# paddle_ormqr Design Document

| API name                | paddle.ormqr                     |
| ----------------------- | -------------------------------- |
| Author                  | Chen-Lun-Hao                     |
| Submission date         | 2024-03-27                       |
| Version                 | V2.0                             |
| Dependent Paddle branch | develop                          |
| File name               | 20240326_api_design_for_ormqr.md |

# I. Overview

## 1. Background

To enrich Paddle's API surface, the API `paddle.ormqr` is to be added.

This API belongs to the PaddlePaddle open-source individual contribution contest, API development task [No.28: add an ormqr API to Paddle](https://github.com/PaddlePaddle/Paddle/issues/62905).

## 2. Goals

Compute the product of a general matrix with a product of Householder matrices: multiply the (m, n) matrix C (given by `other`) by a matrix Q, where Q is represented by the Householder reflectors (`x`, `tau`); the formula below makes this precise.

The API is expected to support:

- paddle.linalg.ormqr as a standalone function
- Tensor.ormqr as a Tensor method
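
For reference, a minimal formalization under the standard LAPACK packing convention (an assumption stated here for clarity; the final documentation wording may differ): Q is the product of k Householder reflectors encoded by the columns of `x` and the scalars in `tau`,

$$
Q = H_1 H_2 \cdots H_k, \qquad H_i = I - \tau_i \, v_i v_i^{H},
$$

where $v_i$ is the i-th column of `x` with its above-diagonal entries ignored and an implicit unit diagonal. `ormqr` applies $\mathrm{op}(Q) \in \{Q, Q^{H}\}$ to `other` from the left or right without ever materializing Q.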
## 3. Motivation

Give Paddle a way to compute the product of a general matrix with this structured matrix, enriching Paddle's API surface.

# II. Current Status in Paddle

Paddle currently lacks an implementation of this functionality.
# III. Survey of Existing Solutions

## PyTorch

PyTorch provides the API `torch.ormqr(input, tau, other, left=True, transpose=False, *, out=None) → Tensor`, together with the corresponding `torch.Tensor.ormqr`; a usage sketch follows the parameter list below.

Its documentation describes it as:

> Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix. Multiplies a m×n matrix C (given by other) with a matrix Q, where Q is represented using Householder reflectors (input, tau).

Its parameters are:

- `Tensor` input: tensor of shape `(*, mn, k)` where `*` is zero or more batch dimensions and mn equals m or n depending on `left`.
- `Tensor` tau: tensor of shape `(*, min(mn, k))` where `*` is zero or more batch dimensions.
- `Tensor` other: tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions.
- `bool` left: controls the order of multiplication.
- `bool` transpose: controls whether the matrix Q is conjugate transposed or not.
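
For illustration, a minimal end-to-end use of `torch.ormqr` paired with `torch.geqrf`, which produces the packed reflectors. The cross-check against an explicitly materialized Q is a sketch and assumes a LAPACK-backed build, where both routines share the same reflector convention:

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 3)   # matrix to factor: m = 4, k = 3
C = torch.randn(4, 5)   # general matrix to multiply: m = 4, n = 5

# geqrf returns the packed Householder reflectors and tau of A's QR factorization.
a, tau = torch.geqrf(A)

# ormqr applies Q to C without materializing Q: res = Q @ C.
res = torch.ormqr(a, tau, C, left=True, transpose=False)

# Cross-check against the explicit (complete) Q.
Q, _ = torch.linalg.qr(A, mode='complete')
print(torch.allclose(res, Q @ C, atol=1e-6))  # expected: True
```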
### Implementation

The CPU kernel shipped in PyTorch 2.2 delegates to LAPACK through the `lapackOrmqr` wrapper; see [BatchLinearAlgebraKernel.cpp](https://github.com/pytorch/pytorch/blob/99c822c0ba747fad8528ff6b57712abdbdc2c093/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp#L2710):
```cpp
template <typename scalar_t>
void apply_ormqr(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
#if !AT_BUILD_WITH_LAPACK()
  TORCH_CHECK(false, "Calling torch.ormqr on a CPU tensor requires compiling ",
      "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
#else
  using value_t = typename c10::scalar_value_type<scalar_t>::type;

  char side = left ? 'L' : 'R';
  char trans = transpose ? (input.is_complex() ? 'C' : 'T') : 'N';

  auto input_data = input.const_data_ptr<scalar_t>();
  auto tau_data = tau.const_data_ptr<scalar_t>();
  auto other_data = other.data_ptr<scalar_t>();

  auto input_matrix_stride = matrixStride(input);
  auto other_matrix_stride = matrixStride(other);
  auto tau_stride = tau.size(-1);
  auto batch_size = batchCount(input);
  auto m = other.size(-2);
  auto n = other.size(-1);
  auto k = tau.size(-1);
  auto lda = std::max<int64_t>(1, left ? m : n);
  auto ldc = std::max<int64_t>(1, m);
  int info = 0;

  // LAPACK's requirement
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY((left ? m : n) >= k);

  // Query for the optimal size of the workspace tensor
  int lwork = -1;
  scalar_t wkopt;
  lapackOrmqr<scalar_t>(side, trans, m, n, k, const_cast<scalar_t*>(input_data), lda, const_cast<scalar_t*>(tau_data), other_data, ldc, &wkopt, lwork, &info);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
  lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
  Tensor work = at::empty({lwork}, input.options());

  for (const auto i : c10::irange(batch_size)) {
    const scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
    scalar_t* other_working_ptr = &other_data[i * other_matrix_stride];
    const scalar_t* tau_working_ptr = &tau_data[i * tau_stride];

    // now compute the actual result
    lapackOrmqr<scalar_t>(
        side, trans, m, n, k,
        const_cast<scalar_t*>(input_working_ptr), lda,
        const_cast<scalar_t*>(tau_working_ptr),
        other_working_ptr, ldc,
        work.data_ptr<scalar_t>(), lwork, &info);

    // info from lapackOrmqr only reports if the i-th parameter is wrong
    // so we don't need to check it all the time
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
  }
#endif
}

// This is a type dispatching helper function for 'apply_ormqr'
void ormqr_kernel(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "ormqr_cpu", [&]{
    apply_ormqr<scalar_t>(input, tau, other, left, transpose);
  });
}
```

## MindSpore

MindSpore provides the interface `mindspore.ops.ormqr`:

- `mindspore.ops.ormqr(input, tau, other, left=True, transpose=False)`

Its implementation (see https://www.mindspore.cn/docs/zh-CN/master/_modules/mindspore/ops/function/math_func.html#ormqr):
```python
def _get_cache_prim(cls: Primitive) -> Primitive:
    """
    Wrapper function, get a primitive by it's all args.

    Args:
        cls (Primitive): The Primitive need be wrapped.

    Returns:
        Function, a new function with return a primitive by it's all args.

    Examples:
        >>> # Example1:
        >>> from mindspore.ops._primitive_cache import _get_cache_prim
        >>> input_x = Tensor(np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]).astype(np.float32))
        >>> axis = [0, 1]
        >>> p=2
        >>> keep_dims=False
        >>> epsilon=1e-12
        >>> _lp_norm = _get_cache_prim(P.LpNorm)(axis, p, keep_dims, epsilon)
        >>> output = _lp_norm(input_x)
        >>> print(output)
        [ 9.165152 10.954452]
        >>> # Example2:
        >>> from mindspore.ops._primitive_cache import _get_cache_prim
        >>> input_x = Tensor(np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]).astype(np.float32))
        >>> axis = [0, 1]
        >>> _lp_norm = _get_cache_prim(P.LpNorm)(axis, 2, keep_dims=False, epsilon=1e-12)
        >>> output = _lp_norm(input_x)
        >>> print(output)
        [ 9.165152 10.954452]
    """

    def _new_prim_for_graph(*args, **kwargs) -> Primitive:
        return cls(*args, **kwargs)

    def _get_cache_prim_for_pynative(*args, **kwargs) -> Primitive:
        """Get a primitive singleton by it's all args."""
        global _PRIM_CACHE
        key = (str(cls),)
        str_args = [str(arg) for arg in args]
        key += tuple(str_args)
        for attr_name in kwargs:
            attr_value = kwargs.get(attr_name)
            key += (attr_name + ":" + str(attr_value),)
        # Note: The key must be a str.
        key = str(key)
        if key not in _PRIM_CACHE:
            prim = Primitive.__new__(cls, *args, **kwargs)
            # Only init once.
            prim.__init__(*args, **kwargs)
            _PRIM_CACHE[key] = prim
        return _PRIM_CACHE.get(key)

    if _is_need_compile(_temp_func):  # @jit.cond: True
        return _new_prim_for_graph
    return _get_cache_prim_for_pynative


def ormqr(input, tau, other, left=True, transpose=False):
    r"""
    Calculates two matrices multiplication of a product of a general matrix with Householder matrices.
    Calculates the product of a matrix C(given by `other`) with dimensions (m, n) and a matrix Q which is represented
    using Householder reflectors (`input`, `tau`). Returns a Tensor.

    Args:
        input (Tensor): Tensor of shape :math:`(*, mn, k)`, when `left` is True, mn equals to m,
            otherwise, mn equals to n. And `*` is zero or more batch dimensions.
        tau (Tensor): Tensor of shape :math:`(*, min(mn, k))` where `*` is zero or more batch dimensions,
            and its type is the same as `input`.
        other (Tensor): Tensor of shape :math:`(*, m, n)` where `*` is zero or more batch dimensions,
            and its type is the same as `input`.
        left (bool, optional): determines the order of multiplication. If True, computes op(Q) \* `other` ,
            otherwise, compute `other` \* op(Q). Default: ``True`` .
        transpose (bool, optional): If True, the matrix Q is conjugate transposed,
            otherwise, not conjugate transposing matrix Q. Default: ``False`` .

    Returns:
        Tensor, with the same type and shape as `other`.

    Raises:
        TypeError: If `input` or `tau` or `other` is not Tensor.
        TypeError: If dtype of `input` or `tau` or `other` is not one of: float64, float32, complex64, complex128.
        ValueError: If the dimension of `input` or `other` is less than 2D.
        ValueError: If rank(`input`) - rank(`tau`) != 1.
        ValueError: If tau.shape[:-2] != input.shape[:-2]
        ValueError: If other.shape[:-2] != input.shape[:-2]
        ValueError: If left == true, other.shape[-2] < tau.shape[-1].
        ValueError: If left == true, other.shape[-2] != input.shape[-2].
        ValueError: If left == false, other.shape[-1] < tau.shape[-1].
        ValueError: If left == false, other.shape[-1] != input.shape[-2].

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> input = Tensor(np.array([[-114.6, 10.9, 1.1], [-0.304, 38.07, 69.38], [-0.45, -0.17, 62]]),
        ...                mindspore.float32)
        >>> tau = Tensor(np.array([1.55, 1.94, 3.0]), mindspore.float32)
        >>> other = Tensor(np.array([[-114.6, 10.9, 1.1],
        ...                          [-0.304, 38.07, 69.38],
        ...                          [-0.45, -0.17, 62]]), mindspore.float32)
        >>> output = ops.ormqr(input, tau, other)
        >>> print(output)
        [[ 63.82713 -13.823125 -116.28614 ]
        [ -53.659264 -28.157839 -70.42702 ]
        [ -79.54292 24.00183 -41.34253 ]]
    """

    ormqr_ = _get_cache_prim(Ormqr)(left, transpose)
    return ormqr_(input, tau, other)
```

# IV. Comparative Analysis

Comparing PyTorch with MindSpore:

- The implementation approaches differ.

  PyTorch implements the kernel in C++, delegating to LAPACK; MindSpore implements the function in Python as a thin wrapper around a dedicated `Ormqr` primitive.
# V. Design and Implementation

Paddle's existing operators already support matrix transposition, row/column computations, and related operations, so `ormqr` can be composed from existing paddle operators. Because `input` and `other` are required to have the same `ndim`, parameters such as `decrease_axes` are not needed. A minimal composition sketch follows.
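
The sketch below illustrates the composition route using only existing paddle ops. The helper name `ormqr_ref` is illustrative, not the final implementation: it covers the unbatched, real-valued case and omits input validation, batching, and complex dtypes:

```python
import paddle

def ormqr_ref(x, tau, other, left=True, transpose=False):
    # Reference sketch: materialize Q from the packed Householder reflectors
    # with existing paddle ops, then multiply. Unbatched, real dtypes only.
    mn = x.shape[-2]
    q = paddle.eye(mn, dtype=x.dtype)
    for i in range(tau.shape[-1]):
        # i-th reflector column: zeros above the diagonal, an implicit unit
        # diagonal, then the below-diagonal entries packed in column i of x.
        v = paddle.concat([paddle.zeros([i, 1], dtype=x.dtype),
                           paddle.ones([1, 1], dtype=x.dtype),
                           x[i + 1:, i:i + 1]], axis=0)
        h = paddle.eye(mn, dtype=x.dtype) - tau[i] * (v @ v.T)  # H_i = I - tau_i v_i v_i^T
        q = q @ h                                               # Q = H_1 H_2 ... H_k
    if transpose:
        q = q.T                    # complex dtypes would need conj().T instead
    return q @ other if left else other @ q
```

A dedicated kernel (as in PyTorch) would avoid materializing Q; the sketch only demonstrates that the semantics are expressible with current paddle ops.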
## Naming and Signature

Add the Python API:

```python
paddle.linalg.ormqr(input, tau, other, left=True, transpose=False)
```
Parameters (a usage sketch follows the list):

- input: (Tensor) shape `(*, mn, k)`, where mn equals m when `left` is True and n otherwise; `*` denotes zero or more batch dimensions.
- tau: (Tensor) shape `(*, min(mn, k))`, where `*` denotes zero or more batch dimensions; same dtype as input.
- other: (Tensor) shape `(*, m, n)`, where `*` denotes zero or more batch dimensions; same dtype as input.
- left: (bool, optional) controls the multiplication order. If True, computes op(Q) @ other; otherwise computes other @ op(Q). Default: True.
- transpose: (bool, optional) if True, the matrix Q is conjugate transposed; otherwise Q is used as is. Default: False.
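
Assuming the API lands with exactly this signature (hypothetical until merged; the Tensor-method argument order mirrors `torch.Tensor.ormqr`), the intended call patterns, reusing the values from the MindSpore example above, would be:

```python
import paddle

x = paddle.to_tensor([[-114.6, 10.9, 1.1],
                      [-0.304, 38.07, 69.38],
                      [-0.45, -0.17, 62.0]], dtype='float32')
tau = paddle.to_tensor([1.55, 1.94, 3.0], dtype='float32')
other = paddle.to_tensor([[-114.6, 10.9, 1.1],
                          [-0.304, 38.07, 69.38],
                          [-0.45, -0.17, 62.0]], dtype='float32')

out = paddle.linalg.ormqr(x, tau, other)                 # op(Q) @ other
out_r = x.ormqr(tau, other, left=False, transpose=True)  # other @ op(Q)^H, method form
```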
## Lower-level OP Design

No new lower-level OPs are involved.
# VI. Testing and Acceptance Criteria

- GPU test scenarios
- coverage of the supported Tensor dtypes
- numerical correctness checks (see the property-based sketch below)
- multi-dimensional (batched) inputs
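
A property-based sketch for the correctness item above, assuming the proposed `paddle.linalg.ormqr` exists: choosing tau_i = 2 / (v_i^T v_i) makes every reflector exactly orthogonal, so Q recovered by applying the operator to the identity must be orthogonal and must reproduce the implicit product:

```python
import numpy as np
import paddle

np.random.seed(0)
m, k, n = 5, 3, 4
x_np = np.random.rand(m, k)
tau_np = np.empty(k)
for i in range(k):
    v = np.concatenate(([1.0], x_np[i + 1:, i]))  # implicit unit diagonal
    tau_np[i] = 2.0 / (v @ v)                     # makes H_i = I - tau_i v v^T orthogonal

x = paddle.to_tensor(x_np)                        # float64
tau = paddle.to_tensor(tau_np)
other = paddle.to_tensor(np.random.rand(m, n))

q = paddle.linalg.ormqr(x, tau, paddle.eye(m, dtype='float64'))  # materialize Q
np.testing.assert_allclose((q @ q.T).numpy(), np.eye(m), atol=1e-10)
np.testing.assert_allclose(paddle.linalg.ormqr(x, tau, other).numpy(),
                           (q @ other).numpy(), atol=1e-10)
```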
# VII. Feasibility Analysis and Schedule

Industry implementations exist as references and the required Python APIs are all available, so the work can be completed within the open-source individual contribution challenge.

# VIII. Impact

No impact on other modules is expected.
# Glossary

# Appendix and References

- [【Hackathon 6th No.4】Add ormqr API to Paddle](https://github.com/PaddlePaddle/community/pull/668)
- [PyTorch torch.ormqr documentation](https://pytorch.org/docs/stable/generated/torch.ormqr.html)