11# paddle_ormqr 设计文档
22
3- | API 名称 | paddle.ormqr |
4- | ----------------------------------------------------------------- | ------- -------------------------------- |
5- | 提交作者 ` <input type="checkbox" class="rowselector hidden"> ` | Chen-Lun-Hao |
6- | 提交时间 ` <input type="checkbox" class="rowselector hidden"> ` | 2024-03-27 |
7- | 版本号 | V2.0 |
8- | 依赖飞桨版本 ` <input type="checkbox" class="rowselector hidden"> ` | develop |
9- | 文件名 | 20240326_api_design_for_ormqr.md ` <br> ` |
3+ | API 名称 | paddle.ormqr |
4+ | ------------ | -------------------------------- |
5+ | 提交作者 | Chen-Lun-Hao |
6+ | 提交时间 | 2024-03-27 |
7+ | 版本号 | V2.0 |
8+ | 依赖飞桨版本 | develop |
9+ | 文件名 | 20240326_api_design_for_ormqr.md |
1010
1111# 一、概述
1212
3737
3838## PyTorch
3939
40- PyTorch 中有 API ` torch.ormqr(input, tau, other, left=True, transpose=False, *, out=None) → Tensoor ` 以及对应的 ` torch.Tensor.ormqr `
40+ PyTorch 中有 API ` torch.ormqr(input, tau, other, left=True, transpose=False, *, out=None) → Tensor ` 以及对应的 ` torch.Tensor.ormqr `
4141
4242其介绍为:
4343
@@ -132,6 +132,75 @@ MindSpore 中有 `mindspore.ops.orqmr` 此接口:
132132https:// www.mindspore.cn/ docs/ zh- CN / master/ _modules/ mindspore/ ops/ function/ math_func.html# ormqr
133133
134134```python
135+ class Ormqr(Primitive):
136+ r """
137+ Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix.
138+ Multiplies a( m, n) matrix C ( given by other) with a matrix Q, where Q is represented using Householder
139+ reflectors ( x, tau) , which is the output of geqrf( ) .
140+
141+ Refer to :func:`mindspore. ops. ormqr` for more details.
142+
143+ .. warning::
144+ This is an experimental API that is subject to change or deletion.
145+
146+ Args:
147+ left ( bool, optional) : controls the order of multiplication. If ``True`` , compute op( Q) * C.
148+ If ``False`` , compute C* op( Q) . Default: ``True`` .
149+ transpose( bool, optional) : controls whether the matrix Q is conjugate transposed or not. Default: ``False`` .
150+
151+ Inputs:
152+ - ** x** ( Tensor) - Tensor of shape :math:`( * , mn, k) ` where the value of mn depending on `left`,
153+ When `left` is ``True``, the value of mn is equal to m; otherwise, the value of mn is equal to n.
154+ and `* ` is zero or more batch dimensions.
155+ - ** tau** ( Tensor) - Tensor of shape :math:`( * , min( mn, k)) ` where `* ` is zero or more batch dimensions,
156+ and its type is the same as `x`.
157+ - ** other** ( Tensor) - Tensor of shape :math:`( * , m, n) ` where `* ` is zero or more batch dimensions,
158+ and its type is the same as `x`.
159+
160+ Outputs:
161+ - ** y** ( Tensor) - the output Tensor, has the same shape and data type as `other`.
162+
163+ Raises:
164+ TypeError: If `x` or `tau` or `other` is not Tensor.
165+ TypeError: If dtype of `x` or `tau` or `other` is not one of: float64, float32, complex64, complex128.
166+ ValueError: If `x` or `other` is less than 2D.
167+ ValueError: If rank( x) - rank( tau) != 1.
168+ ValueError: If tau. shape[:-2 ] != x. shape[:-2 ]
169+ ValueError: If other. shape[:-2 ] != x. shape[:-2 ]
170+ ValueError: If left == True, other. shape[-2 ] < tau. shape[-1 ].
171+ ValueError: If left == True, other. shape[-2 ] != x. shape[-2 ].
172+ ValueError: If left == False, other. shape[-1 ] < tau. shape[-1 ].
173+ ValueError: If left == False, other. shape[-1 ] != x. shape[-2 ].
174+
175+ Supported Platforms:
176+ ``GPU``
177+
178+ Examples:
179+ >>> import mindspore
180+ >>> import numpy as np
181+ >>> from mindspore import Tensor, ops
182+ >>> x = Tensor( np. array( [[-114.6, 10.9, 1.1 ], [-0.304, 38.07, 69.38 ], [-0.45, -0.17, 62 ]]) , mindspore. float32)
183+ >>> tau = Tensor( np. array( [1.55, 1.94, 3.0 ]) , mindspore. float32)
184+ >>> other = Tensor( np. array( [[-114.6, 10.9, 1.1 ],
185+ ... [-0.304, 38.07, 69.38 ],
186+ ... [-0.45, -0.17, 62 ]]) , mindspore. float32)
187+ >>> net = ops. Ormqr( )
188+ >>> y = net( x, tau, other)
189+ >>> print( y)
190+ [[ 63.82713 -13.823125 -116.28614 ]
191+ [ -53.659264 -28.157839 -70.42702 ]
192+ [ -79.54292 24.00183 -41.34253 ]]
193+ """
194+
195+ @ prim_attr_register
196+ def __init__ (self , left = True , transpose = False ):
197+ """ Initialize Ormqr"""
198+ self .init_prim_io_names(inputs = [' x' , ' tau' , ' other' ], outputs = [' y' ])
199+ self .left = validator.check_value_type(' left' , left, [bool ], self .name)
200+ self .transpose = validator.check_value_type(' transpose' , transpose, [bool ], self .name)
201+ self .add_prim_attr(' left' , self .left)
202+ self .add_prim_attr(' transpose' , self .transpose)
203+
135204def _get_cache_prim(cls : Primitive) -> Primitive:
136205 """
137206 Wrapper function, get a primitive by it's all args.
@@ -250,15 +319,11 @@ def _get_cache_prim(cls: Primitive) -> Primitive:
250319
251320# 四、对比分析
252321
253- 对比 PyTorch 与 MindSpore:
254-
255- - 实现方式不同
256-
257- PyTorch 通过 c++ 实现;MindSpore 通过 python 实现。
322+ 在 Pytorch 以及 MindSpore 框架中,他们对于 ormqr 算子的实现方式不同,Pytorch 中使用的是 c++ 实现,而 MindSpore 中使用的是 python 实现。
258323
259324# 五、设计思路与实现方案
260325
261- paddle 目前的算子已经支持矩阵的转置,行列计算等操作,因此,可以使用 paddle 已有算子实现 `ormqr` ,由于要求输入 ` input ` 与 `othrt` 具有相同的 `ndim` ,因此,不需要使用 `decrease_axes` 等参数 。
326+ paddle 目前的算子已经支持矩阵的转置,行列计算等操作,因此,可以使用 paddle 已有算子实现 `ormqr` 。
262327
263328# # 命名与参数设计
264329
@@ -273,7 +338,7 @@ paddle.orqmr(input, tau, other, left=True, transpose=False)
273338- input : (Tensor) shape(\*,mn,k),当 left 为 True 时, mn 的值等于 m,否则 mn 的值等于 n。 \*表示 Tensor 在轴 0 上的长度为 0 或者大于 0。
274339- tau: (Tensor) shape(\*,min(mn,k)),其中 \_ 表示 Tensor 在轴 0 上的长度为 0 或者大于 0,其类型与 input 相同。
275340- other: (Tensor) shape(\*,m,n),其中 \* 表示 Tensor 在轴 0 上的长度为 0 或者大于 0,其类型与 input 相同。
276- - left: (bool , 可选) 决定了矩阵乘积运算的顺序。如果 left 为 True ,计算顺序为 op(Q) x other ,否则,计算顺序为 other x op(Q)。默认值:True 。
341+ - left: (bool , 可选) 决定了矩阵乘积运算的顺序。如果 left 为 True ,计算顺序为 op(Q) ∗ other ,否则,计算顺序为 other \ * op(Q)。默认值:True。
277342- transpose: (bool , 可选) 如果为 True ,对矩阵 Q 进行共轭转置变换,否则,不对矩阵 Q 进行共轭转置变换。默认值: False 。
278343
279344# # 底层 OP 设计
@@ -282,10 +347,11 @@ paddle.orqmr(input, tau, other, left=True, transpose=False)
282347
283348# 六、测试和验收的考量
284349
285- - GPU 测试场景
286- - 支持各种 Tensor
287- - 需要检查计算正确性
288- - 需要检查多维的情况
350+ - 支持 CPU 、GPU 测试场景
351+ - 支持动态图以及静态图
352+ - 支持各种 Tensor,如:float32, float64, complex64, complex128
353+ - 通过对比 mindspore 框架中的 ormqr 算子输出,计算结果是否一致
354+ - 需要检查二到三维矩阵的计算情况
289355
290356# 七、可行性分析和排期规划
291357
@@ -300,4 +366,5 @@ paddle.orqmr(input, tau, other, left=True, transpose=False)
300366# 附件及参考资料
301367
302368[【Hackathon 6th No.4】为 Paddle 新增 ormqr API ](https:// github.com/ PaddlePaddle/ community/ pull/ 668 )
303- [PyTorch slice_scatter 文档](https:// pytorch.org/ docs/ stable/ generated/ torch.slice_scatter.html)
369+
370+ [PyTorch ormqr 文档](https:// pytorch.org/ docs/ stable/ generated/ torch.ormqr.html)
0 commit comments