Skip to content

Commit 4d1397c

Browse files
committed
update rfc
1 parent 02dcc00 commit 4d1397c

File tree

1 file changed

+87
-20
lines changed

1 file changed

+87
-20
lines changed

rfcs/APIs/20240326_api_design_for_ormqr.md

Lines changed: 87 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# paddle_ormqr 设计文档
22

3-
| API 名称 | paddle.ormqr |
4-
| ----------------------------------------------------------------- | --------------------------------------- |
5-
| 提交作者 `<input type="checkbox" class="rowselector hidden">` | Chen-Lun-Hao |
6-
| 提交时间 `<input type="checkbox" class="rowselector hidden">` | 2024-03-27 |
7-
| 版本号 | V2.0 |
8-
| 依赖飞桨版本 `<input type="checkbox" class="rowselector hidden">` | develop |
9-
| 文件名 | 20240326_api_design_for_ormqr.md `<br>` |
3+
| API 名称 | paddle.ormqr |
4+
| ------------ | -------------------------------- |
5+
| 提交作者 | Chen-Lun-Hao |
6+
| 提交时间 | 2024-03-27 |
7+
| 版本号 | V2.0 |
8+
| 依赖飞桨版本 | develop |
9+
| 文件名 | 20240326_api_design_for_ormqr.md |
1010

1111
# 一、概述
1212

@@ -37,7 +37,7 @@
3737

3838
## PyTorch
3939

40-
PyTorch 中有 API `torch.ormqr(input, tau, other, left=True, transpose=False, *, out=None) → Tensoor` 以及对应的 `torch.Tensor.ormqr`
40+
PyTorch 中有 API `torch.ormqr(input, tau, other, left=True, transpose=False, *, out=None) → Tensor` 以及对应的 `torch.Tensor.ormqr`
4141

4242
其介绍为:
4343

@@ -132,6 +132,75 @@ MindSpore 中有 `mindspore.ops.orqmr` 此接口:
132132
https://www.mindspore.cn/docs/zh-CN/master/_modules/mindspore/ops/function/math_func.html#ormqr
133133

134134
```python
135+
class Ormqr(Primitive):
136+
r"""
137+
Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix.
138+
Multiplies a(m, n) matrix C (given by other) with a matrix Q, where Q is represented using Householder
139+
reflectors (x, tau), which is the output of geqrf().
140+
141+
Refer to :func:`mindspore.ops.ormqr` for more details.
142+
143+
.. warning::
144+
This is an experimental API that is subject to change or deletion.
145+
146+
Args:
147+
left (bool, optional): controls the order of multiplication. If ``True`` , compute op(Q)*C.
148+
If ``False`` , compute C*op(Q). Default: ``True`` .
149+
transpose(bool, optional): controls whether the matrix Q is conjugate transposed or not.Default: ``False`` .
150+
151+
Inputs:
152+
- **x** (Tensor) - Tensor of shape :math:`(*, mn, k)` where the value of mn depending on `left`,
153+
When `left` is ``True``, the value of mn is equal to m; otherwise, the value of mn is equal to n.
154+
and `*` is zero or more batch dimensions.
155+
- **tau** (Tensor) - Tensor of shape :math:`(*, min(mn, k))` where `*` is zero or more batch dimensions,
156+
and its type is the same as `x`.
157+
- **other** (Tensor) - Tensor of shape :math:`(*, m, n)` where `*` is zero or more batch dimensions,
158+
and its type is the same as `x`.
159+
160+
Outputs:
161+
- **y** (Tensor) - the output Tensor, has the same shape and data type as `other`.
162+
163+
Raises:
164+
TypeError: If `x` or `tau` or `other` is not Tensor.
165+
TypeError: If dtype of `x` or `tau` or `other` is not one of: float64, float32, complex64, complex128.
166+
ValueError: If `x` or `other` is less than 2D.
167+
ValueError: If rank(x) - rank(tau) != 1.
168+
ValueError: If tau.shape[:-2] != x.shape[:-2]
169+
ValueError: If other.shape[:-2] != x.shape[:-2]
170+
ValueError: If left == True, other.shape[-2] < tau.shape[-1].
171+
ValueError: If left == True, other.shape[-2] != x.shape[-2].
172+
ValueError: If left == False, other.shape[-1] < tau.shape[-1].
173+
ValueError: If left == False, other.shape[-1] != x.shape[-2].
174+
175+
Supported Platforms:
176+
``GPU``
177+
178+
Examples:
179+
>>> import mindspore
180+
>>> import numpy as np
181+
>>> from mindspore import Tensor, ops
182+
>>> x = Tensor(np.array([[-114.6, 10.9, 1.1], [-0.304, 38.07, 69.38], [-0.45, -0.17, 62]]), mindspore.float32)
183+
>>> tau = Tensor(np.array([1.55, 1.94, 3.0]), mindspore.float32)
184+
>>> other = Tensor(np.array([[-114.6, 10.9, 1.1],
185+
... [-0.304, 38.07, 69.38],
186+
... [-0.45, -0.17, 62]]), mindspore.float32)
187+
>>> net = ops.Ormqr()
188+
>>> y = net(x, tau, other)
189+
>>> print(y)
190+
[[ 63.82713 -13.823125 -116.28614 ]
191+
[ -53.659264 -28.157839 -70.42702 ]
192+
[ -79.54292 24.00183 -41.34253 ]]
193+
"""
194+
195+
@prim_attr_register
196+
def __init__(self, left=True, transpose=False):
197+
"""Initialize Ormqr"""
198+
self.init_prim_io_names(inputs=['x', 'tau', 'other'], outputs=['y'])
199+
self.left = validator.check_value_type('left', left, [bool], self.name)
200+
self.transpose = validator.check_value_type('transpose', transpose, [bool], self.name)
201+
self.add_prim_attr('left', self.left)
202+
self.add_prim_attr('transpose', self.transpose)
203+
135204
def _get_cache_prim(cls: Primitive) -> Primitive:
136205
"""
137206
Wrapper function, get a primitive by it's all args.
@@ -250,15 +319,11 @@ def _get_cache_prim(cls: Primitive) -> Primitive:
250319

251320
# 四、对比分析
252321

253-
对比 PyTorch 与 MindSpore:
254-
255-
- 实现方式不同
256-
257-
PyTorch 通过 c++ 实现;MindSpore 通过 python 实现。
322+
在 Pytorch 以及 MindSpore 框架中,他们对于 ormqr 算子的实现方式不同,Pytorch 中使用的是 c++实现,而 MindSpore 中使用的是 python 实现。
258323

259324
# 五、设计思路与实现方案
260325

261-
paddle 目前的算子已经支持矩阵的转置,行列计算等操作,因此,可以使用 paddle 已有算子实现 `ormqr` ,由于要求输入 `input``othrt` 具有相同的 `ndim`,因此,不需要使用 `decrease_axes` 等参数
326+
paddle 目前的算子已经支持矩阵的转置,行列计算等操作,因此,可以使用 paddle 已有算子实现 `ormqr`
262327

263328
## 命名与参数设计
264329

@@ -273,7 +338,7 @@ paddle.orqmr(input, tau, other, left=True, transpose=False)
273338
- input: (Tensor) shape(\*,mn,k),当 left 为 True 时, mn 的值等于 m,否则 mn 的值等于 n。 \*表示 Tensor 在轴 0 上的长度为 0 或者大于 0。
274339
- tau: (Tensor) shape(\*,min(mn,k)),其中 \_ 表示 Tensor 在轴 0 上的长度为 0 或者大于 0,其类型与 input 相同。
275340
- other: (Tensor) shape(\*,m,n),其中 \* 表示 Tensor 在轴 0 上的长度为 0 或者大于 0,其类型与 input 相同。
276-
- left: (bool, 可选) 决定了矩阵乘积运算的顺序。如果 left 为 True ,计算顺序为 op(Q) x other ,否则,计算顺序为 other x op(Q)。默认值:True
341+
- left: (bool, 可选) 决定了矩阵乘积运算的顺序。如果 left 为 True ,计算顺序为 op(Q) other ,否则,计算顺序为 other \* op(Q)。默认值:True。
277342
- transpose: (bool, 可选) 如果为 True ,对矩阵 Q 进行共轭转置变换,否则,不对矩阵 Q 进行共轭转置变换。默认值: False
278343

279344
## 底层 OP 设计
@@ -282,10 +347,11 @@ paddle.orqmr(input, tau, other, left=True, transpose=False)
282347

283348
# 六、测试和验收的考量
284349

285-
- GPU 测试场景
286-
- 支持各种 Tensor
287-
- 需要检查计算正确性
288-
- 需要检查多维的情况
350+
- 支持 CPUGPU 测试场景
351+
- 支持动态图以及静态图
352+
- 支持各种 Tensor,如:float32, float64, complex64, complex128
353+
- 通过对比 mindspore 框架中的 ormqr 算子输出,计算结果是否一致
354+
- 需要检查二到三维矩阵的计算情况
289355

290356
# 七、可行性分析和排期规划
291357

@@ -300,4 +366,5 @@ paddle.orqmr(input, tau, other, left=True, transpose=False)
300366
# 附件及参考资料
301367

302368
[【Hackathon 6th No.4】为 Paddle 新增 ormqr API](https://github.com/PaddlePaddle/community/pull/668)
303-
[PyTorch slice_scatter 文档](https://pytorch.org/docs/stable/generated/torch.slice_scatter.html)
369+
370+
[PyTorch ormqr 文档](https://pytorch.org/docs/stable/generated/torch.ormqr.html)

0 commit comments

Comments
 (0)