# paddle_ormqr Design Document

| API name              | paddle.ormqr                     |
| --------------------- | -------------------------------- |
| Author                | Chen-Lun-Hao                     |
| Submitted             | 2024-03-27                       |
| Version               | V2.0                             |
| Target Paddle version | develop                          |
| File name             | 20240326_api_design_for_ormqr.md |

# I. Overview

## 1. Background

To enrich Paddle's API surface, the API `paddle.ormqr` is to be added.

This API belongs to the Paddle open-source individual contribution contest, task [No.28: add an ormqr API to Paddle](https://github.com/PaddlePaddle/Paddle/issues/62905).
## 2. Goals

Compute the product of a general matrix with a product of Householder matrices: multiply the (m, n) matrix C (given by `other`) by a matrix Q, where Q is represented by its Householder reflectors (x, tau).
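For reference, under the standard LAPACK storage convention, Q is the implicit product of the k Householder reflectors encoded in (x, tau):

$$Q = H_1 H_2 \cdots H_k, \qquad H_i = I - \tau_i \, v_i v_i^{H},$$

where $v_i$ is the i-th Householder vector (zeros above position i, an implicit 1 at position i, and the remaining entries taken from the i-th column of x), and $\tau_i$ is the i-th entry of tau.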

The API is expected to support:

- `paddle.linalg.ormqr` as a standalone function
- `Tensor.ormqr` as a Tensor method
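A minimal sketch of the intended call pattern (the API does not yet exist; the signature follows the proposal in Section V, and the shapes are illustrative only):

```python
import paddle

x = paddle.randn([4, 3])      # Householder vectors, e.g. from a QR factorization
tau = paddle.randn([3])       # Householder scalar coefficients
other = paddle.randn([4, 2])  # the general matrix C, shape (m, n)

out = paddle.linalg.ormqr(x, tau, other)                    # op(Q) @ C
out_t = paddle.linalg.ormqr(x, tau, other, transpose=True)  # op(Q)^H @ C
```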

## 3. Significance

Adds a way to compute the product of a general matrix with an implicitly represented Householder matrix product, enriching Paddle's API surface.

# II. Current Status in Paddle

Paddle currently has no implementation of this functionality.
# III. Survey of Industry Solutions

## PyTorch

PyTorch provides the API `torch.ormqr(input, tau, other, left=True, transpose=False, *, out=None) → Tensor`, along with the corresponding `torch.Tensor.ormqr`.

Its documentation reads:

> Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix. Multiplies a m×n matrix C (given by other) with a matrix Q, where Q is represented using Householder reflectors (input, tau).

Its parameters are:

- `Tensor` input: tensor of shape (\*, mn, k), where \* is zero or more batch dimensions and mn equals m when `left` is True, n otherwise.
- `Tensor` tau: tensor of shape (\*, min(mn, k)), where \* is zero or more batch dimensions.
- `Tensor` other: tensor of shape (\*, m, n), where \* is zero or more batch dimensions.
- `bool` left: controls the order of multiplication.
- `bool` transpose: controls whether the matrix Q is conjugate transposed or not.
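For context, a small self-contained check of the PyTorch API (all calls below are existing `torch` functions): it builds a Householder representation with `torch.geqrf` and verifies `torch.ormqr` against an explicitly materialized Q.

```python
import torch

A = torch.randn(4, 4)
C = torch.randn(4, 3)

a, tau = torch.geqrf(A)       # Householder representation of Q from QR(A)
QC = torch.ormqr(a, tau, C)   # Q @ C, computed without ever forming Q

Q = torch.linalg.householder_product(a, tau)  # materialize Q explicitly
assert torch.allclose(QC, Q @ C, atol=1e-5)
```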

### Implementation

In PyTorch 2.2, the CPU kernel for the `ormqr` operation is implemented on top of LAPACK as follows; see [BatchLinearAlgebraKernel.cpp](https://github.com/pytorch/pytorch/blob/99c822c0ba747fad8528ff6b57712abdbdc2c093/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp#L2710):

```cpp
template <typename scalar_t>
void apply_ormqr(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
#if !AT_BUILD_WITH_LAPACK()
  TORCH_CHECK(false, "Calling torch.ormqr on a CPU tensor requires compiling ",
      "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
#else
  using value_t = typename c10::scalar_value_type<scalar_t>::type;

  char side = left ? 'L' : 'R';
  char trans = transpose ? (input.is_complex() ? 'C' : 'T') : 'N';

  auto input_data = input.const_data_ptr<scalar_t>();
  auto tau_data = tau.const_data_ptr<scalar_t>();
  auto other_data = other.data_ptr<scalar_t>();

  auto input_matrix_stride = matrixStride(input);
  auto other_matrix_stride = matrixStride(other);
  auto tau_stride = tau.size(-1);
  auto batch_size = batchCount(input);
  auto m = other.size(-2);
  auto n = other.size(-1);
  auto k = tau.size(-1);
  auto lda = std::max<int64_t>(1, left ? m : n);
  auto ldc = std::max<int64_t>(1, m);
  int info = 0;

  // LAPACK's requirement
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY((left ? m : n) >= k);

  // Query for the optimal size of the workspace tensor
  int lwork = -1;
  scalar_t wkopt;
  lapackOrmqr<scalar_t>(side, trans, m, n, k, const_cast<scalar_t*>(input_data), lda, const_cast<scalar_t*>(tau_data), other_data, ldc, &wkopt, lwork, &info);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
  lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
  Tensor work = at::empty({lwork}, input.options());

  for (const auto i : c10::irange(batch_size)) {
    const scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
    scalar_t* other_working_ptr = &other_data[i * other_matrix_stride];
    const scalar_t* tau_working_ptr = &tau_data[i * tau_stride];

    // now compute the actual result
    lapackOrmqr<scalar_t>(
        side, trans, m, n, k,
        const_cast<scalar_t*>(input_working_ptr), lda,
        const_cast<scalar_t*>(tau_working_ptr),
        other_working_ptr, ldc,
        work.data_ptr<scalar_t>(), lwork, &info);

    // info from lapackOrmqr only reports if the i-th parameter is wrong
    // so we don't need to check it all the time
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
  }
#endif
}

// This is a type dispatching helper function for 'apply_ormqr'
void ormqr_kernel(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "ormqr_cpu", [&]{
    apply_ormqr<scalar_t>(input, tau, other, left, transpose);
  });
}
```

## MindSpore

MindSpore provides the interface `mindspore.ops.ormqr`:

- `mindspore.ops.ormqr(input, tau, other, left=True, transpose=False)`

Its implementation (from
https://www.mindspore.cn/docs/zh-CN/master/_modules/mindspore/ops/function/math_func.html#ormqr):

```python
def _get_cache_prim(cls: Primitive) -> Primitive:
    """
    Wrapper function, get a primitive by it's all args.

    Args:
        cls (Primitive): The Primitive need be wrapped.

    Returns:
        Function, a new function with return a primitive by it's all args.

    Examples:
        >>> # Example1:
        >>> from mindspore.ops._primitive_cache import _get_cache_prim
        >>> input_x = Tensor(np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]).astype(np.float32))
        >>> axis = [0, 1]
        >>> p = 2
        >>> keep_dims = False
        >>> epsilon = 1e-12
        >>> _lp_norm = _get_cache_prim(P.LpNorm)(axis, p, keep_dims, epsilon)
        >>> output = _lp_norm(input_x)
        >>> print(output)
        [ 9.165152 10.954452]
        >>> # Example2:
        >>> from mindspore.ops._primitive_cache import _get_cache_prim
        >>> input_x = Tensor(np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]).astype(np.float32))
        >>> axis = [0, 1]
        >>> _lp_norm = _get_cache_prim(P.LpNorm)(axis, 2, keep_dims=False, epsilon=1e-12)
        >>> output = _lp_norm(input_x)
        >>> print(output)
        [ 9.165152 10.954452]
    """

    def _new_prim_for_graph(*args, **kwargs) -> Primitive:
        return cls(*args, **kwargs)

    def _get_cache_prim_for_pynative(*args, **kwargs) -> Primitive:
        """Get a primitive singleton by it's all args."""
        global _PRIM_CACHE
        key = (str(cls),)
        str_args = [str(arg) for arg in args]
        key += tuple(str_args)
        for attr_name in kwargs:
            attr_value = kwargs.get(attr_name)
            key += (attr_name + ":" + str(attr_value),)
        # Note: The key must be a str.
        key = str(key)
        if key not in _PRIM_CACHE:
            prim = Primitive.__new__(cls, *args, **kwargs)
            # Only init once.
            prim.__init__(*args, **kwargs)
            _PRIM_CACHE[key] = prim
        return _PRIM_CACHE.get(key)

    if _is_need_compile(_temp_func):  # @jit.cond: True
        return _new_prim_for_graph
    return _get_cache_prim_for_pynative


def ormqr(input, tau, other, left=True, transpose=False):
    r"""
    Calculates two matrices multiplication of a product of a general matrix with Householder matrices.
    Calculates the product of a matrix C(given by `other`) with dimensions (m, n) and a matrix Q which is represented
    using Householder reflectors (`input`, `tau`). Returns a Tensor.

    Args:
        input (Tensor): Tensor of shape :math:`(*, mn, k)`, when `left` is True, mn equals to m,
            otherwise, mn equals to n. And `*` is zero or more batch dimensions.
        tau (Tensor): Tensor of shape :math:`(*, min(mn, k))` where `*` is zero or more batch dimensions,
            and its type is the same as `input`.
        other (Tensor): Tensor of shape :math:`(*, m, n)` where `*` is zero or more batch dimensions,
            and its type is the same as `input`.
        left (bool, optional): determines the order of multiplication. If True, computes op(Q) \* `other` ,
            otherwise, compute `other` \* op(Q). Default: ``True`` .
        transpose (bool, optional): If True, the matrix Q is conjugate transposed,
            otherwise, not conjugate transposing matrix Q. Default: ``False`` .

    Returns:
        Tensor, with the same type and shape as `other`.

    Raises:
        TypeError: If `input` or `tau` or `other` is not Tensor.
        TypeError: If dtype of `input` or `tau` or `other` is not one of: float64, float32, complex64, complex128.
        ValueError: If the dimension of `input` or `other` is less than 2D.
        ValueError: If rank(`input`) - rank(`tau`) != 1.
        ValueError: If tau.shape[:-2] != input.shape[:-2]
        ValueError: If other.shape[:-2] != input.shape[:-2]
        ValueError: If left == true, other.shape[-2] < tau.shape[-1].
        ValueError: If left == true, other.shape[-2] != input.shape[-2].
        ValueError: If left == false, other.shape[-1] < tau.shape[-1].
        ValueError: If left == false, other.shape[-1] != input.shape[-2].

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> input = Tensor(np.array([[-114.6, 10.9, 1.1], [-0.304, 38.07, 69.38], [-0.45, -0.17, 62]]),
        ...                mindspore.float32)
        >>> tau = Tensor(np.array([1.55, 1.94, 3.0]), mindspore.float32)
        >>> other = Tensor(np.array([[-114.6, 10.9, 1.1],
        ...                          [-0.304, 38.07, 69.38],
        ...                          [-0.45, -0.17, 62]]), mindspore.float32)
        >>> output = ops.ormqr(input, tau, other)
        >>> print(output)
        [[  63.82713  -13.823125 -116.28614 ]
         [ -53.659264 -28.157839  -70.42702 ]
         [ -79.54292   24.00183   -41.34253 ]]
    """
    ormqr_ = _get_cache_prim(Ormqr)(left, transpose)
    return ormqr_(input, tau, other)
```

# IV. Comparative Analysis

Comparing PyTorch and MindSpore:

- The implementation approaches differ: PyTorch implements the kernel in C++ on top of LAPACK, while MindSpore composes the API in Python around an `Ormqr` primitive.

# V. Design and Implementation

Paddle's existing operators already support matrix transposition, row/column computations, and related operations, so `ormqr` can be implemented by composing existing Paddle operators, as sketched below. Since the inputs `input` and `other` are required to have the same `ndim`, parameters such as `decrease_axes` are not needed.
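A minimal single-matrix sketch of such a composition, using only ops that already exist in Paddle (`paddle.eye`, slicing, matmul). The name `ormqr_ref` is illustrative: the loop materializes Q explicitly, so this is a reference sketch rather than the final implementation.

```python
import paddle

def ormqr_ref(x, tau, other, left=True, transpose=False):
    # Reference sketch: materialize Q = H_1 H_2 ... H_k from the Householder
    # reflectors stored in (x, tau), then multiply with `other`.
    mn = x.shape[-2]                      # equals m when left=True, n otherwise
    q = paddle.eye(mn, dtype=x.dtype)
    for i in range(tau.shape[-1]):
        v = x[:, i : i + 1].clone()       # i-th reflector as a column vector
        if i > 0:
            v[:i] = 0.0                   # entries above the pivot are zero
        v[i] = 1.0                        # implicit unit entry on the diagonal
        h = paddle.eye(mn, dtype=x.dtype) - tau[i] * (v @ v.T)
        q = q @ h                         # accumulate H_1 H_2 ... H_k
    if transpose:
        q = q.T                           # conj().T would be needed for complex dtypes
    return q @ other if left else other @ q
```

Batched inputs would apply the same construction per matrix; forming Q could also be delegated to `paddle.linalg.householder_product` if that API is available in the target branch.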

## Naming and Parameter Design

Add the Python API:

```python
paddle.ormqr(input, tau, other, left=True, transpose=False)
```

Parameters:
- input: (Tensor) shape (\*, mn, k); when left is True, mn equals m, otherwise mn equals n. \* denotes zero or more batch dimensions.
- tau: (Tensor) shape (\*, min(mn, k)), where \* denotes zero or more batch dimensions; same dtype as input.
- other: (Tensor) shape (\*, m, n), where \* denotes zero or more batch dimensions; same dtype as input.
- left: (bool, optional) determines the order of the multiplication. If True, computes op(Q) × other; otherwise computes other × op(Q). Default: True.
- transpose: (bool, optional) if True, the matrix Q is conjugate transposed; otherwise Q is used as-is. Default: False.
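A hypothetical shape-validation helper mirroring these constraints (the name and error handling are illustrative, not the final implementation):

```python
def _check_ormqr_shapes(input, tau, other, left=True):
    # Hypothetical pre-checks mirroring the parameter table above.
    assert input.ndim >= 2 and other.ndim >= 2, "input/other must be at least 2-D"
    assert input.ndim - tau.ndim == 1, "tau must have exactly one fewer dimension than input"
    mn = other.shape[-2] if left else other.shape[-1]
    assert input.shape[-2] == mn, "input.shape[-2] must equal m (left=True) or n (left=False)"
    assert tau.shape[-1] <= mn, "number of reflectors cannot exceed mn"
    assert input.shape[:-2] == other.shape[:-2] == tau.shape[:-1], "batch dims must match"
```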

## Underlying OP Design

No new underlying OP is involved.
# VI. Testing and Acceptance Considerations

- GPU test scenarios
- Support for various Tensor dtypes and shapes
- Verify numerical correctness (e.g. against a reference implementation, as sketched below)
- Cover batched (multi-dimensional) inputs
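One possible correctness check uses `torch.ormqr` as the reference oracle (a sketch; it assumes the proposed `paddle.linalg.ormqr` signature above):

```python
import numpy as np
import paddle
import torch

a_np = np.random.rand(4, 4).astype("float32")
c_np = np.random.rand(4, 3).astype("float32")

# Build a valid Householder representation (x, tau) via torch.geqrf,
# then take torch.ormqr as the expected result.
x_t, tau_t = torch.geqrf(torch.from_numpy(a_np))
expected = torch.ormqr(x_t, tau_t, torch.from_numpy(c_np)).numpy()

out = paddle.linalg.ormqr(
    paddle.to_tensor(x_t.numpy()),
    paddle.to_tensor(tau_t.numpy()),
    paddle.to_tensor(c_np),
)
np.testing.assert_allclose(out.numpy(), expected, rtol=1e-5, atol=1e-5)
```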

# VII. Feasibility Analysis and Schedule

Industry implementations are available for reference and the relevant Python APIs already exist, so the work can be completed during the open-source individual contribution contest.

# VIII. Impact

No impact on other modules.

# Glossary

# Appendices and References

- [【Hackathon 6th No.4】Add ormqr API to Paddle](https://github.com/PaddlePaddle/community/pull/668)
- [PyTorch ormqr documentation](https://pytorch.org/docs/stable/generated/torch.ormqr.html)
