Skip to content

Commit d19633d

Browse files
committed
update ormqr rfc
1 parent 3c40046 commit d19633d

File tree

1 file changed

+18
-20
lines changed

1 file changed

+18
-20
lines changed

rfcs/APIs/20240326_api_design_for_ormqr.md

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# paddle_ormqr 设计文档
22

3-
| API 名称 | paddle.ormqr |
4-
| ------------------------------------------------------------------- | ----------------------------------------- |
5-
| 提交作者 `<input type="checkbox" class="rowselector hidden">` | Chen-Lun-Hao |
6-
| 提交时间 `<input type="checkbox" class="rowselector hidden">` | 2024-03-27 |
7-
| 版本号 | V2.0 |
8-
| 依赖飞桨版本 `<input type="checkbox" class="rowselector hidden">` | develop |
9-
| 文件名 | 20240326_api_design_for_ormqr.md `<br>` |
3+
| API 名称 | paddle.ormqr |
4+
| ------------ | -------------------------------- |
5+
| 提交作者 | Chen-Lun-Hao |
6+
| 提交时间 | 2024-03-27 |
7+
| 版本号 | V2.0 |
8+
| 依赖飞桨版本 | develop |
9+
| 文件名 | 20240326_api_design_for_ormqr.md |
1010

1111
# 一、概述
1212

@@ -37,7 +37,7 @@
3737

3838
## PyTorch
3939

40-
PyTorch 中有 API `torch.ormqr(input, tau, other, left=True, transpose=False, *, out=None) → Tensoor` 以及对应的 `torch.Tensor.ormqr`
40+
PyTorch 中有 API `torch.ormqr(input, tau, other, left=True, transpose=False, *, out=None) → Tensor` 以及对应的 `torch.Tensor.ormqr`
4141

4242
其介绍为:
4343

@@ -319,15 +319,11 @@ def _get_cache_prim(cls: Primitive) -> Primitive:
319319

320320
# 四、对比分析
321321

322-
对比 PyTorch 与 MindSpore:
323-
324-
- 实现方式不同
325-
326-
PyTorch 通过 c++ 实现;MindSpore 通过 python 实现。
322+
在 Pytorch 以及 MindSpore 框架中,他们对于 ormqr 算子的实现方式不同,Pytorch 中使用的是 c++实现,而 MindSpore 中使用的是 python 实现。
327323

328324
# 五、设计思路与实现方案
329325

330-
paddle 目前的算子已经支持矩阵的转置,行列计算等操作,因此,可以使用 paddle 已有算子实现 `ormqr` ,由于要求输入 `input``othrt` 具有相同的 `ndim`,因此,不需要使用 `decrease_axes` 等参数
326+
paddle 目前的算子已经支持矩阵的转置,行列计算等操作,因此,可以使用 paddle 已有算子实现 `ormqr`
331327

332328
## 命名与参数设计
333329

@@ -342,7 +338,7 @@ paddle.orqmr(input, tau, other, left=True, transpose=False)
342338
- input: (Tensor) shape(\*,mn,k),当 left 为 True 时, mn 的值等于 m,否则 mn 的值等于 n。 \*表示 Tensor 在轴 0 上的长度为 0 或者大于 0。
343339
- tau: (Tensor) shape(\*,min(mn,k)),其中 \_ 表示 Tensor 在轴 0 上的长度为 0 或者大于 0,其类型与 input 相同。
344340
- other: (Tensor) shape(\*,m,n),其中 \* 表示 Tensor 在轴 0 上的长度为 0 或者大于 0,其类型与 input 相同。
345-
- left: (bool, 可选) 决定了矩阵乘积运算的顺序。如果 left 为 True ,计算顺序为 op(Q) x other ,否则,计算顺序为 other x op(Q)。默认值:True
341+
- left: (bool, 可选) 决定了矩阵乘积运算的顺序。如果 left 为 True ,计算顺序为 op(Q) other ,否则,计算顺序为 other \* op(Q)。默认值:True。
346342
- transpose: (bool, 可选) 如果为 True ,对矩阵 Q 进行共轭转置变换,否则,不对矩阵 Q 进行共轭转置变换。默认值: False
347343

348344
## 底层 OP 设计
@@ -351,10 +347,11 @@ paddle.orqmr(input, tau, other, left=True, transpose=False)
351347

352348
# 六、测试和验收的考量
353349

354-
- GPU 测试场景
355-
- 支持各种 Tensor
356-
- 需要检查计算正确性
357-
- 需要检查多维的情况
350+
- 支持 CPUGPU 测试场景
351+
- 支持动态图以及静态图
352+
- 支持各种 Tensor,如:float32, float64, complex64, complex128
353+
- 通过对比 mindspore 框架中的 ormqr 算子输出,计算结果是否一致
354+
- 需要检查二到三维矩阵的计算情况
358355

359356
# 七、可行性分析和排期规划
360357

@@ -369,4 +366,5 @@ paddle.orqmr(input, tau, other, left=True, transpose=False)
369366
# 附件及参考资料
370367

371368
[【Hackathon 6th No.4】为 Paddle 新增 ormqr API](https://github.com/PaddlePaddle/community/pull/668)
372-
[PyTorch slice_scatter 文档](https://pytorch.org/docs/stable/generated/torch.slice_scatter.html)
369+
370+
[PyTorch ormqr 文档](https://pytorch.org/docs/stable/generated/torch.ormqr.html)

0 commit comments

Comments
 (0)