Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.36】 为 Paddle 新增 matrix_exp API -part #59715

Merged
merged 12 commits into from
Dec 21, 2023

Conversation

megemini
Copy link
Contributor

@megemini megemini commented Dec 5, 2023

PR types

New features

PR changes

APIs

Description

【Hackathon 5th No.36】 为 Paddle 新增 matrix_exp API

RFC: PaddlePaddle/community#674 PaddlePaddle/community#775

此实现基于上述的 RFC ~ RFC 已经调研的十分详细了,所以没有再编写 RFC ~

这里的实现与 Tensorflow/Eigen 思路类似,都是使用 pade 近似 ~

涉及文件:

  • python/paddle/tensor/linalg.py 算法实现
  • test/legacy_test/test_linalg_matrix_exp.py 单测

cc @zade23 看到您这边已经很久没有更新 RFC 与 PR 了,特此越俎代庖了,如果您那边已经实现好了,还请告知 ~

非常感谢!

Copy link

paddle-bot bot commented Dec 5, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Dec 5, 2023
@zade23
Copy link
Contributor

zade23 commented Dec 5, 2023

@megemini
说来惭愧,由于算子开发经验不足和技术能力受限,我并没有完成在RFC之后的开发,感谢您完成后续的工作,借此也观摩学习一下您的实现思路~

@luotao1
Copy link
Contributor

luotao1 commented Dec 6, 2023

@megemini 顺师傅,因为RFC还没合入,能不能根据comment意见帮忙修改下RFC呢

@megemini
Copy link
Contributor Author

megemini commented Dec 6, 2023

@luotao1
Copy link
Contributor

luotao1 commented Dec 6, 2023

image 需要补一下单测来通过覆盖率

@megemini
Copy link
Contributor Author

megemini commented Dec 6, 2023

需要补一下单测来通过覆盖率

涛姐效率太高了 🤣🤣🤣 ~~~ 先别急,还有 docstring 、单测都要补充完善 ~ 算法也要看一下还有没有可以优化的地方 ~

等完成了再@大家 ~

非常感谢!👍👍👍

@megemini
Copy link
Contributor Author

megemini commented Dec 7, 2023

Update 20231207

  • 增加算法对于 0d/scalar 的处理
  • 统一静态/动态图的数据类型处理
  • 增加 error 测试
  • 增加 docstring

@luotao1 @cxxly

测试的时候发现,pir 下的 paddle.static.nn.cond 好像有问题,参考如下的测试程序(官方示例):

import numpy as np
import scipy

import paddle
from paddle import base
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

@test_with_pir_api
def test_static():
    paddle.enable_static()
    places = [base.CPUPlace()]
    if core.is_compiled_with_cuda():
        places.append(base.CUDAPlace(0))

    for place in places:
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            def true_func():
                return paddle.full(shape=[1, 2], dtype='int32',
                                fill_value=1), paddle.full(shape=[2, 3],
                                                            dtype='bool',
                                                            fill_value=True)


            def false_func():
                return paddle.full(shape=[3, 4], dtype='float32',
                                fill_value=3), paddle.full(shape=[4, 5],
                                                            dtype='int64',
                                                            fill_value=2)


            x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
            y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
            pred = paddle.less_than(x=x, y=y, name=None)
            ret = paddle.static.nn.cond(pred, true_func, false_func)

            exe = paddle.static.Executor(place)

            res = exe.run(
                fetch_list=[ret],
            )
            print(res)

if __name__ == '__main__':
    test_static()

运行后输出:

aistudio@jupyter-942478-6602454:~/matrix_exp$ python test_cond.py 
I1207 13:58:48.306679 21584 program_interpreter.cc:212] New Executor is Running.
I1207 13:58:48.307299 21584 conditional_block_op.cc:98] [ControlFlow][ConditionalBlock] New Executor is Running.
[array([[1, 1]], dtype=int32), array([[ True,  True,  True],
       [ True,  True,  True]])]
Traceback (most recent call last):
  File "/home/aistudio/matrix_exp/test_cond.py", line 47, in <module>
    test_static()
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/paddle/pir_utils.py", line 115, in impl
    func(*args, **kwargs)
  File "/home/aistudio/matrix_exp/test_cond.py", line 37, in test_static
    ret = paddle.static.nn.cond(pred, true_func, false_func)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/paddle/static/nn/control_flow.py", line 1254, in cond
    if_op.update_output()
RuntimeError: (PreconditionNotMet) The output[0] type of true block and false block must be equal.
  [Hint: Expected op.operand(i).type() == argument.output_types[i], but received op.operand(i).type():pd_op.tensor<3x4xf32> != argument.output_types[i]:pd_op.tensor<1x2xi32>.] (at /paddle/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc:91)

可以看到,旧 ir 的运行是正常的,而 pir 提示输出错误 ~

参考 https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/static/nn/cond_cn.html#cond

paddle.static.nn.cond 中说明:

true_fn 和 false_fn 返回的元组必须形状相同,但是里面的 Tensor 形状可以不同。

上面的输出是一个 tuple,含有两个 tensor:

  • true_func: 第一个是 shape=[1, 2], 第二个是 shape=[2, 3]
  • false_func:第一个是 shape=[3, 4], 第二个是 shape=[4, 5]

虽然每个 tensor 的 shape 不一样,但是都是长度为 2 的 tuple,所以应该没什么问题 ~

另外,在 matrix_exp 这个函数中使用的时候,即使所有 tensor 的 shape 一样,也会报错!所以问题可能还得往回追溯 ~

还请帮忙定位解决一下,不然 pir 下的测试跑不通 ~

非常感谢!

@winter-wang
Copy link
Contributor

winter-wang commented Dec 7, 2023

Update 20231207

  • 增加算法对于 0d/scalar 的处理
  • 统一静态/动态图的数据类型处理
  • 增加 error 测试
  • 增加 docstring

@luotao1 @cxxly

测试的时候发现,pir 下的 paddle.static.nn.cond 好像有问题,参考如下的测试程序(官方示例):

import numpy as np
import scipy

import paddle
from paddle import base
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

@test_with_pir_api
def test_static():
    paddle.enable_static()
    places = [base.CPUPlace()]
    if core.is_compiled_with_cuda():
        places.append(base.CUDAPlace(0))

    for place in places:
        with paddle.static.program_guard(
            paddle.static.Program(), paddle.static.Program()
        ):
            def true_func():
                return paddle.full(shape=[1, 2], dtype='int32',
                                fill_value=1), paddle.full(shape=[2, 3],
                                                            dtype='bool',
                                                            fill_value=True)


            def false_func():
                return paddle.full(shape=[3, 4], dtype='float32',
                                fill_value=3), paddle.full(shape=[4, 5],
                                                            dtype='int64',
                                                            fill_value=2)


            x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
            y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
            pred = paddle.less_than(x=x, y=y, name=None)
            ret = paddle.static.nn.cond(pred, true_func, false_func)

            exe = paddle.static.Executor(place)

            res = exe.run(
                fetch_list=[ret],
            )
            print(res)

if __name__ == '__main__':
    test_static()

运行后输出:

aistudio@jupyter-942478-6602454:~/matrix_exp$ python test_cond.py 
I1207 13:58:48.306679 21584 program_interpreter.cc:212] New Executor is Running.
I1207 13:58:48.307299 21584 conditional_block_op.cc:98] [ControlFlow][ConditionalBlock] New Executor is Running.
[array([[1, 1]], dtype=int32), array([[ True,  True,  True],
       [ True,  True,  True]])]
Traceback (most recent call last):
  File "/home/aistudio/matrix_exp/test_cond.py", line 47, in <module>
    test_static()
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/paddle/pir_utils.py", line 115, in impl
    func(*args, **kwargs)
  File "/home/aistudio/matrix_exp/test_cond.py", line 37, in test_static
    ret = paddle.static.nn.cond(pred, true_func, false_func)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/paddle/static/nn/control_flow.py", line 1254, in cond
    if_op.update_output()
RuntimeError: (PreconditionNotMet) The output[0] type of true block and false block must be equal.
  [Hint: Expected op.operand(i).type() == argument.output_types[i], but received op.operand(i).type():pd_op.tensor<3x4xf32> != argument.output_types[i]:pd_op.tensor<1x2xi32>.] (at /paddle/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc:91)

可以看到,旧 ir 的运行是正常的,而 pir 提示输出错误 ~

参考 https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/static/nn/cond_cn.html#cond

paddle.static.nn.cond 中说明:

true_fn 和 false_fn 返回的元组必须形状相同,但是里面的 Tensor 形状可以不同。

上面的输出是一个 tuple,含有两个 tensor:

  • true_func: 第一个是 shape=[1, 2], 第二个是 shape=[2, 3]
  • false_func:第一个是 shape=[3, 4], 第二个是 shape=[4, 5]

虽然每个 tensor 的 shape 不一样,但是都是长度为 2 的 tuple,所以应该没什么问题 ~

另外,在 matrix_exp 这个函数中使用的时候,即使所有 tensor 的 shape 一样,也会报错!所以问题可能还得往回追溯 ~

还请帮忙定位解决一下,不然 pir 下的测试跑不通 ~

非常感谢!

PIR当前暂时不支持if算子两个分支返回值shape不一致的情景(预计还需要一周左右的时间才会支持该情景)。 如果该场景不可避免,建议先在旧IR场景下验证。只需删除单测装饰符@test_with_pir_api 即可。

@luotao1
Copy link
Contributor

luotao1 commented Dec 8, 2023

@winter-wang 沟通:该PR可先在旧IR场景下验证。只需删除单测装饰符@test_with_pir_api 即可。
cc @cxxly 知晓

@luotao1
Copy link
Contributor

luotao1 commented Dec 11, 2023

test_linalg_matrix_exp (Timeout)

如果15秒内跑不完,可以适当扩大下时间(下面是例子120秒,实际按需扩大)
image

@megemini
Copy link
Contributor Author

@cxxly CI 应该没啥问题了,请评审 ~ 谢谢!

另外,单测里面对于 RTOL/ATOL 在 win32 降了测试精度,主要是 windows-inference 好像在 float32 下差了 5e-03,其他环境(包括 windows ci)都没有精度问题 ~ 我在本地 ubuntu 测试 float32 可以到 1e-06,float64 可以到 1e-09 ~ 不清楚是不是这个环境有啥特殊的?

else:
RTOL = {'float32': 1e-03, 'float64': 1e-05}
ATOL = {'float32': 1e-03, 'float64': 1e-05}

Copy link
Contributor

@cxxly cxxly Dec 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

对于FP32/FP64来说,这个精度阈值设置过大,有没有参考标准,比如可以参考下PyTorch 单测是怎么设置的。

另外,对于这个单测中的实验数据,可以本地和PyTorch再对比下,看看误差多大,理论上 fp32 rtol<1e-6, fp64 rtol<1e-15

如果有问题,需要定位原因,是基础API误差(看了用到的一些API,根据经验应该概率比较小),还是计算逻辑某地地方疏忽了

Copy link
Contributor Author

@megemini megemini Dec 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pytorch 的测试在 test/test_linalg.py,有两个主要的精度测试逻辑:

  • test_linalg_matrix_exp_analytic 应该是 pade 近似的做法,里面用到了逆矩阵 torch.inverse,这里 atol=1e-3,单双精度一样的标准 ~
  • test_linalg_matrix_exp_compare_with_taylor 泰勒多项式的算法,这里 atol=1e-2,单双精度一样 ~

以上两者都是用 numpy 实现的 ~

我这里是用的 scipy 比对,应该比自己写靠谱点 ~~~ 🤣🤣🤣

以下是 torch 的测试:

    def test_linalg_matrix_exp_analytic(self, device, dtype):
        expm = torch.linalg.matrix_exp
        ...
        def run_test(*n):
            if dtype == torch.float:
                thetas = [
                    1.192092800768788e-07,  # deg 1
                    5.978858893805233e-04,  # deg 2
                    5.116619363445086e-02,  # deg 4
                    5.800524627688768e-01,  # deg 8
                    1.461661507209034e+00,  # deg 12
                    3.010066362817634e+00   # deg 18
                ]
            else:  # if torch.double
                thetas = [
                    2.220446049250313e-16,  # deg 1
                    2.580956802971767e-08,  # deg 2
                    3.397168839976962e-04,  # deg 4
                    4.991228871115323e-02,  # deg 8
                    2.996158913811580e-01,  # deg 12
                    1.090863719290036e+00   # deg 18
                ]

            # generate input
            q = gen_good_cond_number_matrices(*n)
            q_ = q.cpu().numpy()
            qinv = torch.inverse(q)
            qinv_ = qinv.cpu().numpy()
            d = torch.randn(n[:-1], dtype=dtype, device=device)
            x = torch.from_numpy(
                np.matmul(q_, np.matmul(torch.diag_embed(d).cpu().numpy(), qinv_))).to(device)
            x_norm, _ = x.abs().sum(-2).max(-1)

            # test simple analytic whatever norm generated
            mexp = expm(x)
            mexp_analytic = np.matmul(
                q_,
                np.matmul(
                    torch.diag_embed(d.exp()).cpu().numpy(),
                    qinv_
                )
            )
            self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0)

            # generate norms to test different degree expansions
            sample_norms = []
            for i in range(len(thetas) - 1):
                sample_norms.append(0.5 * (thetas[i] + thetas[i + 1]))
            sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2]

            # matrices to equal norm
            for sample_norm in sample_norms:
                x_normalized = normalize_to_1_operator_norm(x, sample_norm)

                mexp = expm(x_normalized)
                mexp_analytic = np.matmul(
                    q_,
                    np.matmul(
                        torch.diag_embed((d / x_norm.unsqueeze(-1) * sample_norm).exp()).cpu().numpy(),
                        qinv_
                    )
                )
                self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0)
    def test_linalg_matrix_exp_compare_with_taylor(self, device, dtype):
        ...
        def run_test(*n):
            degs = [1, 2, 4, 8, 12, 18]
            if dtype == torch.float:
                thetas = [
                    1.192092800768788e-07,  # deg 1
                    5.978858893805233e-04,  # deg 2
                    5.116619363445086e-02,  # deg 4
                    5.800524627688768e-01,  # deg 8
                    1.461661507209034e+00,  # deg 12
                    3.010066362817634e+00   # deg 18
                ]
            else:  # if torch.double
                thetas = [
                    2.220446049250313e-16,  # deg 1
                    2.580956802971767e-08,  # deg 2
                    3.397168839976962e-04,  # deg 4
                    4.991228871115323e-02,  # deg 8
                    2.996158913811580e-01,  # deg 12
                    1.090863719290036e+00   # deg 18
                ]

            # generate norms to test different degree expansions
            sample_norms = []
            for i in range(len(thetas) - 1):
                sample_norms.append(0.5 * (thetas[i] + thetas[i + 1]))
            sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2]
            degs = [degs[0]] + degs

            for sample_norm, deg in zip(sample_norms, degs):
                x = gen_good_cond_number_matrices(*n)
                x = normalize_to_1_operator_norm(x, sample_norm)

                mexp = torch.linalg.matrix_exp(x)
                mexp_taylor = scale_square(x, deg)

                self.assertEqual(mexp, mexp_taylor, atol=1e-2, rtol=0.0)

倒也不必都参考 torch,scipy 我觉得更可靠一点(至少在科学计算方面是经过考验的 ~):

  • 大尺寸矩阵,本地测试与 scipy 在 ubuntu 的 float32 可以到 1e-6,float64 可以到 1e-9 ~
  • 此次测试用例,在 aistudio 上 'float32': 1e-06, 'float64': 1e-15 可以过测 ~

线上 ci 保守点所以写成目前的范围,看看还要改吗?! ~~~ 😆😆😆

image

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改下吧,主要是避免后续一些基础算子逻辑的修改,降低了这个API精度,但是CI又感知不到

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

收到!

@megemini megemini requested a review from cxxly December 20, 2023 05:40
@megemini
Copy link
Contributor Author

Update 20231220

  • 提升测试精度

@cxxly 请评审 ~

Copy link
Contributor

@cxxly cxxly left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jeff41404
Copy link
Contributor

The reference rfc in the description is not merged, and it is necessary to reference the merged rfc

@megemini
Copy link
Contributor Author

The reference rfc in the description is not merged, and it is necessary to reference the merged rfc

rfc is here: PaddlePaddle/community#775

Copy link
Contributor

@jeff41404 jeff41404 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@luotao1
Copy link
Contributor

luotao1 commented Dec 21, 2023

可以提交对应的中文文档

Copy link
Contributor

@sunzhongkai588 sunzhongkai588 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for docs

@luotao1 luotao1 changed the title 【Hackathon 5th No.36】 为 Paddle 新增 matrix_exp API 【Hackathon 5th No.36】 为 Paddle 新增 matrix_exp API -part Dec 21, 2023
@luotao1 luotao1 merged commit db804cd into PaddlePaddle:develop Dec 21, 2023
29 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants