【Hackathon 5th No.42】 Add API conversion rules to the Paddle code conversion tool (Group 1, No. 1-20) -part #306

Merged
merged 19 commits into from
Oct 20, 2023


Li-fAngyU
Contributor

@Li-fAngyU Li-fAngyU commented Sep 27, 2023

PR Docs

PR APIs

2.torch.cummax
3.torch.Tensor.cummax
4.torch.Tensor.cummin
6.torch.Tensor.view
7.torch.Tensor.view_as
8.torch.Tensor.i0
9.torch.Tensor.index_out
14.torch.as_strided
15.torch.Tensor.as_strided
18.torch.Tensor.unfold
20.torch.clamp_max

@paddle-bot

paddle-bot bot commented Sep 27, 2023

Thanks for your contribution!

@paddle-bot paddle-bot bot added the contributor External developers label Sep 27, 2023
@@ -16,60 +16,58 @@

from apibase import APIBase

obj = APIBase("torch.Tensor")
obj = APIBase("torch.tensor")
Collaborator

torch.Tensor and torch.tensor are two different APIs; do not change this.


import textwrap

from apibase import APIBase

obj = APIBase("torch.Tensor.T")
obj = APIBase("torch.Tensor.t")
Collaborator

torch.Tensor.T and torch.tensor.t are two different APIs; do not change this.


def get_paddle_class_nodes(self, func, args, kwargs):
    if kwargs:
        if len(kwargs) == 1 and "shape_or_dtype" in kwargs:
Collaborator

Torch code can't contain "shape_or_dtype", can it? This matcher rewrites torch code, so the func, args, and kwargs passed in are all torch content. Also, func has to be processed with self.parse_func first.

return "misidentify"


class TensorView_asMatcher(BaseMatcher):
Collaborator

Doesn't Paddle also have paddle.Tensor.view_as?

"torch.Tensor.view": {
"Matcher": "TensorViewMatcher"
},
"torch.Tensor.view_as": {
Collaborator

Paddle has paddle.Tensor.view_as.

)
return CODE_TEMPLATE

def get_paddle_class_nodes(self, func, args, kwargs):
Collaborator

The helper function, and the decision of when a call is a misidentification, when it can simply be left unchanged, and when it should be left unchanged plus the helper function: this code still needs a careful pass.

obj = APIBase("torch.Tensor.view")


def test_case_1():
Collaborator

The unit tests must cover all usages of torch.Tensor.view.

@Li-fAngyU
Contributor Author

Understood, thanks for the corrections. I'll go fix these!

@Li-fAngyU
Contributor Author

Done @zhwesky2010

@Li-fAngyU Li-fAngyU requested a review from zhwesky2010 October 11, 2023 16:44
@Li-fAngyU
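Contributor Author

Li-fAngyU commented Oct 12, 2023

@zhwesky2010 How should I handle the errors in the unit tests? In my development environment, running pytest tests/test_Tensor_view.py shows every case passing, but running pytest tests to execute all unit tests produces a TypeError. Could you help me find the cause?

Note: both Paddle and Torch are the CPU versions.

[screenshot: e843d7a04a3bdc2b52e262f1ac45c36]

Above is the error from running pytest tests; below is the result of running pytest tests/test_Tensor_view.py on its own.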

@zhwesky2010
Collaborator

Add an is_aux_api to the unit test:
[screenshot: infoflow 2023-10-17 12-22-09]

@Li-fAngyU
Contributor Author

OK.

Collaborator

@zhwesky2010 zhwesky2010 left a comment

torch.Tensor.view is the most complex API in this set; please add more test cases based on the API documentation:
https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch-tensor-view

"Matcher": "TensorViewMatcher"
},
"torch.Tensor.view_as": {
"Matcher": "GenericMatcher",
Collaborator

Doesn't this one take any parameters? args_list needs to be written.

self.write_aux_code()
return "unchange"

return "misidentify"
Collaborator

Try not to copy other APIs here; instead, analyze this API's own parameter characteristics. If a keyword argument is passed, for example x.view(dtype=torch.float32), is that a misidentification?

obj.run(pytorch_code, ["result"])


# # 因为当前paddle.view 不支持没有 shape 的 tensor,所以该案例无法正常运行
Collaborator

Use English for comments.

result = a.view(torch.int32)
"""
)
obj.run(pytorch_code, ["result"])
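Collaborator

Usages with keyword arguments need to be added. At a minimum:

Unit test coverage requirement: when an API has multiple parameters, the tests should include the various ways of passing them (all keyword arguments, no keyword arguments, reordered keyword arguments, and all default parameters omitted; these four cases must be considered). Don't cover only the simplest, most common usage; list at least 5 different use cases (the more the better).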
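
A rough way to check a set of test calls against the four required categories, sketched with the standard-library ast module. The helper call_styles and the category names are hypothetical, and the parameter order for torch.as_strided (input, size, stride, storage_offset) is assumed from the PyTorch documentation; nothing here is part of PaConvert.

```python
import ast

# Parameter order assumed from the PyTorch docs for torch.as_strided.
AS_STRIDED_PARAMS = ["input", "size", "stride", "storage_offset"]


def call_styles(call_src, param_order):
    """Return the coverage categories that one call expression falls into."""
    call = ast.parse(call_src, mode="eval").body
    if not isinstance(call, ast.Call):
        raise ValueError("expected a call expression")

    kw_names = [kw.arg for kw in call.keywords]
    styles = set()
    if call.keywords and not call.args:
        styles.add("all_keywords")
        # Keywords count as reordered when they differ from the signature order.
        expected = [p for p in param_order if p in kw_names]
        if kw_names != expected:
            styles.add("reordered_keywords")
    if call.args and not call.keywords:
        styles.add("no_keywords")
    if len(call.args) + len(call.keywords) < len(param_order):
        styles.add("defaults_omitted")
    return styles


# The calls are only parsed as text, so torch does not need to be installed.
CASES = [
    "torch.as_strided(x, (2, 2), (1, 2), 0)",
    "torch.as_strided(input=x, size=(2, 2), stride=(1, 2), storage_offset=0)",
    "torch.as_strided(stride=(1, 2), input=x, size=(2, 2), storage_offset=0)",
    "torch.as_strided(x, (2, 2), (1, 2))",
]

covered = set()
for case in CASES:
    covered |= call_styles(case, AS_STRIDED_PARAMS)
print(sorted(covered))
```

A category missing from the printed list would point at a calling style the test file does not exercise yet.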

@Li-fAngyU
Contributor Author

Revised and supplemented!

@Li-fAngyU Li-fAngyU requested a review from zhwesky2010 October 17, 2023 12:29
obj.run(pytorch_code, ["result"])


def test_case_18():
Collaborator

@zhwesky2010 zhwesky2010 Oct 17, 2023

Can the keyword be shape? The torch.Tensor.view unit tests need another pass to consolidate them: increase the variety of cases rather than the number of cases, and cut duplicate cases.

Contributor Author

Right, the keyword cannot be shape. Example code:

import torch
x = torch.arange(4)

x.view(shape=(2,2))
# The error message is:
# TypeError: view() received an invalid combination of arguments - got (shape=tuple, ), but expected one of:
# * (tuple of ints size)
#      didn't match because some of the keywords were incorrect: shape
# * (torch.dtype dtype)
#      didn't match because some of the keywords were incorrect: shape

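The unit tests cover several dtypes because, when torch.view changes a tensor's dtype, the resulting tensor changes whenever the element size of dtype is larger or smaller than that of self.dtype, as the following code shows:

import torch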
x = torch.arange(4)

print(x.view(torch.bool))
# tensor([False, False, False, False, False, False, False, False,  True, False,
#        False, False, False, False, False, False,  True, False, False, False,
#       False, False, False, False,  True, False, False, False, False, False,
#       False, False])

print(x.view(torch.half))
# tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 5.9605e-08, 0.0000e+00,
#        0.0000e+00, 0.0000e+00, 1.1921e-07, 0.0000e+00, 0.0000e+00, 0.0000e+00,
#        1.7881e-07, 0.0000e+00, 0.0000e+00, 0.0000e+00], dtype=torch.float16)

print(x.view(torch.cfloat))
# tensor([0.0000e+00+0.j, 1.4013e-45+0.j, 2.8026e-45+0.j, 4.2039e-45+0.j])

print(x.view(torch.uint8))
# tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0,
#        3, 0, 0, 0, 0, 0, 0, 0], dtype=torch.uint8)
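
The element counts above follow from reinterpreting the same 32 bytes at a different element width. A minimal sketch with the standard-library struct module reproduces the pattern without torch; it assumes that torch.arange(4) defaults to int64 and that the output above came from a little-endian machine.

```python
import struct

# torch.arange(4) holds the int64 values 0..3: four little-endian 8-byte integers.
raw = struct.pack("<4q", 0, 1, 2, 3)

# uint8: every byte becomes one element, so 4 * 8 = 32 elements.
as_uint8 = list(raw)

# bool: also one byte per element; only the non-zero bytes are True.
as_bool = [b != 0 for b in raw]

# half: two bytes per element, so 32 / 2 = 16 elements.
as_half = struct.unpack("<16e", raw)

# cfloat: a pair of float32 values (real, imag) per element, so 32 / 8 = 4 elements.
floats = struct.unpack("<8f", raw)
as_cfloat = [complex(floats[i], floats[i + 1]) for i in range(0, 8, 2)]

print(len(as_uint8), len(as_bool), len(as_half), len(as_cfloat))
```

This is why a dtype with a different element size changes the shape of the result, and why the tests need more than one dtype.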

Collaborator

OK, dtype does need several variants. There is also one more keyword usage: x.view(size=[2, 2])

Contributor Author

OK, I'll add that test case in the next commit.

@Li-fAngyU Li-fAngyU requested a review from zhwesky2010 October 18, 2023 07:54
return "unchange"

if args:
if len(args) > 1 and isinstance(args[0], (ast.Tuple, ast.List)):
Collaborator

@zhwesky2010 zhwesky2010 Oct 19, 2023

The condition here isn't quite right. It should be:

if len(args) == 1:
    if isinstance(args[0], (ast.Tuple, ast.List)):
        return "unchange"
    if isinstance(args[0], ast.Constant) and isinstance(args[0].value, str):
        return "unchange"

These correspond to the three usages list/tuple/dtype; all of them can simply be left as they are.

@Li-fAngyU
Contributor Author

Yes, the logic was off; fixed. Also, could torch.is_pinned and the other APIs in this group be submitted in a separate PR?

@Li-fAngyU Li-fAngyU requested a review from zhwesky2010 October 19, 2023 06:52
@luotao1 luotao1 changed the title 【Hackthon 5 No.42】 Add API conversion rules to the Paddle code conversion tool (Group 1, No. 1-20) 【Hackathon 5th No.42】 Add API conversion rules to the Paddle code conversion tool (Group 1, No. 1-20) Oct 19, 2023
if isinstance(args[0], (ast.Constant)) and isinstance(
args[0].value, str
):
return "unchange"
Collaborator

@zhwesky2010 zhwesky2010 Oct 19, 2023

Looking at it again, there is one more possible usage: passing a variable (which could be a list, a tuple, a torch.dtype, or a constant). There should be one more return below the two ifs:

if len(args) == 1:
    if isinstance(args[0], (ast.Tuple, ast.List)):
        return "unchange"
    if isinstance(args[0], (ast.Constant)) and isinstance(
        args[0].value, str
    ):
        return "unchange"

    self.write_aux_code()
    return "unchange"
else:
    self.write_aux_code()
    return "unchange"

After merging the two branches, it becomes:

if len(args) == 1:
    if isinstance(args[0], (ast.Tuple, ast.List)):
        return "unchange"
    if isinstance(args[0], (ast.Constant)) and isinstance(
        args[0].value, str
    ):
        return "unchange"

self.write_aux_code()
return "unchange"
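
A standalone sketch of the positional-argument branch outlined in the comment above, using only the standard-library ast module. The function name classify_view_args and the (decision, needs_aux_code) return pair are hypothetical stand-ins for the matcher's write_aux_code() call and its return value; this is not the merged PaConvert code, and it does not model the keyword or no-argument cases separately.

```python
import ast


def classify_view_args(call_src):
    """Return (decision, needs_aux_code) for an `x.view(...)` call given as text."""
    call = ast.parse(call_src, mode="eval").body
    if not isinstance(call, ast.Call):
        raise ValueError("expected a call expression")

    args = call.args
    if len(args) == 1:
        # A literal list or tuple shape, e.g. x.view((2, 2)), can stay as written.
        if isinstance(args[0], (ast.Tuple, ast.List)):
            return ("unchange", False)
        # A string constant, e.g. x.view("float32"), can also stay as written.
        if isinstance(args[0], ast.Constant) and isinstance(args[0].value, str):
            return ("unchange", False)

    # Anything else (a variable, several ints, ...) is only known at runtime,
    # so the call stays unchanged but the auxiliary helper is required.
    return ("unchange", True)


for src in ["x.view((2, 2))", "x.view('float32')", "x.view(2, 2)", "x.view(shape)"]:
    print(src, classify_view_args(src))
```

The variable case (x.view(shape)) is the one the extra return covers: the AST alone cannot tell whether shape is a list, a tuple, or a dtype.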

Contributor Author

@Li-fAngyU Li-fAngyU Oct 19, 2023

OK. I found that torch.view() does not accept a str input, so the str check can be dropped. All unit tests still pass.

import torch
x = torch.rand(4)
x.view('uint8')
#   TypeError: view() received an invalid combination of arguments - got (str), but expected one of:
#    * (torch.dtype dtype)
#         didn't match because some of the arguments have invalid types: (str)
#    * (tuple of SymInts size)
#         didn't match because some of the arguments have invalid types: (str)

The final version is as follows:

def get_paddle_class_nodes(self, func, args, kwargs):
    if kwargs:
        if len(kwargs) == 1:
            self.write_aux_code()
            return "unchange"

    if args:
        if len(args) == 1 and isinstance(args[0], (ast.Tuple, ast.List)):
            return "unchange"
        else:
            self.write_aux_code()
            return "unchange"

    return "misidentify"

Would this approach work?

return "unchange"

if args:
if len(args) == 1 and isinstance(args[0], (ast.Tuple, ast.List)):
Collaborator

With this change, you have dropped one case:
if isinstance(args[0], (ast.Constant)) and isinstance(args[0].value, str):

Contributor Author

That is because the input to torch.view is never a str, so I removed that case. The details are in the comment above.

Collaborator

@zhwesky2010 zhwesky2010 Oct 20, 2023

That code path is still reached during actual conversion: for torch.view(x, torch.float32), the inner torch.float32 is first converted to "float32", and only then is the outer torch.view converted.
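
The inner-first ordering described above can be illustrated with a small ast.NodeTransformer. This is a simplified sketch, not PaConvert's transformer: the class InnerFirst and its single dtype rule are hypothetical, and the example uses the method form x.view(torch.float32).

```python
import ast


class InnerFirst(ast.NodeTransformer):
    """Rewrite child nodes before the enclosing call is inspected."""

    def __init__(self):
        self.seen_view_args = []

    def visit_Attribute(self, node):
        self.generic_visit(node)
        # Hypothetical inner rule: torch.float32 -> "float32".
        if (
            isinstance(node.value, ast.Name)
            and node.value.id == "torch"
            and node.attr == "float32"
        ):
            return ast.copy_location(ast.Constant(value="float32"), node)
        return node

    def visit_Call(self, node):
        # Arguments are visited (and rewritten) first ...
        self.generic_visit(node)
        # ... so the outer matcher already sees a string constant here.
        if isinstance(node.func, ast.Attribute) and node.func.attr == "view":
            self.seen_view_args.append([type(a).__name__ for a in node.args])
        return node


transformer = InnerFirst()
tree = transformer.visit(ast.parse("y = x.view(torch.float32)"))
converted = ast.unparse(ast.fix_missing_locations(tree))
print(converted)
print(transformer.seen_view_args)
```

By the time the view call is examined, its argument is an ast.Constant holding a str, which is why the str check in the matcher is still exercised even though torch itself rejects x.view('uint8').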

Contributor Author

OK, fixed!

Collaborator

@zhwesky2010 zhwesky2010 left a comment

LGTM. This task still has some other APIs. You can also take part in the other 5 tasks; whichever is merged first counts.

@zhwesky2010 zhwesky2010 merged commit 1f5b7b7 into PaddlePaddle:master Oct 20, 2023
@luotao1 luotao1 changed the title 【Hackathon 5th No.42】 Add API conversion rules to the Paddle code conversion tool (Group 1, No. 1-20) 【Hackathon 5th No.42】 Add API conversion rules to the Paddle code conversion tool (Group 1, No. 1-20) -part Oct 20, 2023