11# paddle_ormqr 设计文档
22
3- | API 名称 | paddle.ormqr |
4- | ------------------------------------------------------------------- | --------- -------------------------------- |
5- | 提交作者 ` <input type="checkbox" class="rowselector hidden"> ` | Chen-Lun-Hao |
6- | 提交时间 ` <input type="checkbox" class="rowselector hidden"> ` | 2024-03-27 |
7- | 版本号 | V2.0 |
8- | 依赖飞桨版本 ` <input type="checkbox" class="rowselector hidden"> ` | develop |
9- | 文件名 | 20240326_api_design_for_ormqr.md ` <br> ` |
3+ | API 名称 | paddle.ormqr |
4+ | ------------ | -------------------------------- |
5+ | 提交作者 | Chen-Lun-Hao |
6+ | 提交时间 | 2024-03-27 |
7+ | 版本号 | V2.0 |
8+ | 依赖飞桨版本 | develop |
9+ | 文件名 | 20240326_api_design_for_ormqr.md |
1010
1111# 一、概述
1212
3737
3838## PyTorch
3939
40- PyTorch 中有 API ` torch.ormqr(input, tau, other, left=True, transpose=False, *, out=None) → Tensoor ` 以及对应的 ` torch.Tensor.ormqr `
40+ PyTorch 中有 API ` torch.ormqr(input, tau, other, left=True, transpose=False, *, out=None) → Tensor ` 以及对应的 ` torch.Tensor.ormqr `
4141
4242其介绍为:
4343
@@ -319,15 +319,11 @@ def _get_cache_prim(cls: Primitive) -> Primitive:
319319
320320# 四、对比分析
321321
322- 对比 PyTorch 与 MindSpore:
323-
324- - 实现方式不同
325-
326- PyTorch 通过 c++ 实现;MindSpore 通过 python 实现。
322+ 在 Pytorch 以及 MindSpore 框架中,他们对于 ormqr 算子的实现方式不同,Pytorch 中使用的是 c++ 实现,而 MindSpore 中使用的是 python 实现。
327323
328324# 五、设计思路与实现方案
329325
330- paddle 目前的算子已经支持矩阵的转置,行列计算等操作,因此,可以使用 paddle 已有算子实现 `ormqr` ,由于要求输入 ` input ` 与 `othrt` 具有相同的 `ndim` ,因此,不需要使用 `decrease_axes` 等参数 。
326+ paddle 目前的算子已经支持矩阵的转置,行列计算等操作,因此,可以使用 paddle 已有算子实现 `ormqr` 。
331327
332328# # 命名与参数设计
333329
@@ -342,7 +338,7 @@ paddle.orqmr(input, tau, other, left=True, transpose=False)
342338- input : (Tensor) shape(\*,mn,k),当 left 为 True 时, mn 的值等于 m,否则 mn 的值等于 n。 \*表示 Tensor 在轴 0 上的长度为 0 或者大于 0。
343339- tau: (Tensor) shape(\*,min(mn,k)),其中 \_ 表示 Tensor 在轴 0 上的长度为 0 或者大于 0,其类型与 input 相同。
344340- other: (Tensor) shape(\*,m,n),其中 \* 表示 Tensor 在轴 0 上的长度为 0 或者大于 0,其类型与 input 相同。
345- - left: (bool , 可选) 决定了矩阵乘积运算的顺序。如果 left 为 True ,计算顺序为 op(Q) x other ,否则,计算顺序为 other x op(Q)。默认值:True 。
341+ - left: (bool , 可选) 决定了矩阵乘积运算的顺序。如果 left 为 True ,计算顺序为 op(Q) ∗ other ,否则,计算顺序为 other \ * op(Q)。默认值:True。
346342- transpose: (bool , 可选) 如果为 True ,对矩阵 Q 进行共轭转置变换,否则,不对矩阵 Q 进行共轭转置变换。默认值: False 。
347343
348344# # 底层 OP 设计
@@ -351,10 +347,11 @@ paddle.orqmr(input, tau, other, left=True, transpose=False)
351347
352348# 六、测试和验收的考量
353349
354- - GPU 测试场景
355- - 支持各种 Tensor
356- - 需要检查计算正确性
357- - 需要检查多维的情况
350+ - 支持 CPU 、GPU 测试场景
351+ - 支持动态图以及静态图
352+ - 支持各种 Tensor,如:float32, float64, complex64, complex128
353+ - 通过对比 mindspore 框架中的 ormqr 算子输出,计算结果是否一致
354+ - 需要检查二到三维矩阵的计算情况
358355
359356# 七、可行性分析和排期规划
360357
@@ -369,4 +366,5 @@ paddle.orqmr(input, tau, other, left=True, transpose=False)
369366# 附件及参考资料
370367
371368[【Hackathon 6th No.4】为 Paddle 新增 ormqr API ](https:// github.com/ PaddlePaddle/ community/ pull/ 668 )
372- [PyTorch slice_scatter 文档](https:// pytorch.org/ docs/ stable/ generated/ torch.slice_scatter.html)
369+
370+ [PyTorch ormqr 文档](https:// pytorch.org/ docs/ stable/ generated/ torch.ormqr.html)
0 commit comments