[Hackathon 5th No.42] Add API conversion rules to the Paddle code conversion tool (Group 1, Nos. 1-20) -part #306
Thanks for your contribution!
tests/test_Tensor.py
Outdated
@@ -16,60 +16,58 @@

from apibase import APIBase

-obj = APIBase("torch.Tensor")
+obj = APIBase("torch.tensor")
torch.Tensor and torch.tensor are two different APIs; please do not change this.
tests/test_Tensor_T.py
Outdated

import textwrap

from apibase import APIBase

-obj = APIBase("torch.Tensor.T")
+obj = APIBase("torch.Tensor.t")
torch.Tensor.T and torch.Tensor.t are two different APIs; please do not change this.
paconvert/api_matcher.py
Outdated

    def get_paddle_class_nodes(self, func, args, kwargs):
        if kwargs:
            if len(kwargs) == 1 and "shape_or_dtype" in kwargs:
Torch code cannot contain "shape_or_dtype", can it? This matcher converts torch code, so the func, args, and kwargs passed in all come from torch. Also, func needs to be processed with self.parse_func first.
paconvert/api_matcher.py
Outdated
        return "misidentify"


class TensorView_asMatcher(BaseMatcher):
Doesn't Paddle also have paddle.Tensor.view_as?
"torch.Tensor.view": {
    "Matcher": "TensorViewMatcher"
},
"torch.Tensor.view_as": {
Paddle has paddle.Tensor.view_as.
        )
        return CODE_TEMPLATE

    def get_paddle_class_nodes(self, func, args, kwargs):
The helper function, and the decision about which cases are a misidentification, which can simply be left unchanged, and which need to be left unchanged plus the helper function, all need another careful pass through the code here.
obj = APIBase("torch.Tensor.view")


def test_case_1():
The unit tests must cover every usage of torch.Tensor.view.
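Got it, thanks for the correction. I'll go fix it!
Done @zhwesky2010
@zhwesky2010 How should I handle the errors in the unit tests? In my development environment I ran … Note: Paddle and Torch are both CPU builds. The above is from running …
OK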
torch.Tensor.view is the most complex API in this group; please add more test cases based on the API documentation:
https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch-tensor-view
    "Matcher": "TensorViewMatcher"
},
"torch.Tensor.view_as": {
    "Matcher": "GenericMatcher",
Does this API have no parameters? args_list needs to be filled in.
            self.write_aux_code()
            return "unchange"

        return "misidentify"
Try not to imitate other APIs here; analyze this API's own parameter characteristics instead. If a keyword argument is passed, for example x.view(dtype=torch.float32), is that really a misidentification?
tests/test_Tensor_view.py
Outdated
    obj.run(pytorch_code, ["result"])


# # 因为当前paddle.view 不支持没有 shape 的 tensor,所以该案例无法正常运行
Please write this comment in English.
        result = a.view(torch.int32)
        """
    )
    obj.run(pytorch_code, ["result"])
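Please add usages that pass keyword arguments. At a minimum, follow the coverage rule:
Unit-test coverage requirement: when an API has multiple parameters, the tests should include every argument-passing style (four cases are mandatory: all arguments passed by keyword, no arguments passed by keyword, keywords in a different order, and all default parameters omitted). Do not cover only the simplest common usage; list at least 5 distinct use cases (the more the better).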
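As a rough illustration of the four mandatory argument styles, the sketch below lists one call per style. torch.sum is picked purely as an example API, and CASES and keyword_names are hypothetical names; the snippets are only parsed with ast and never executed, so torch is not required.

```python
import ast

# Illustrative snippets only: torch.sum is an example API, not one from this PR.
CASES = {
    "all keywords": "result = torch.sum(input=x, dim=1, keepdim=True)",
    "no keywords": "result = torch.sum(x, 1, True)",
    "reordered keywords": "result = torch.sum(keepdim=True, dim=1, input=x)",
    "defaults omitted": "result = torch.sum(x)",
}


def keyword_names(snippet):
    """Return the keyword names used by the call on the right-hand side."""
    call = ast.parse(snippet).body[0].value
    return [kw.arg for kw in call.keywords]


for style, snippet in CASES.items():
    print(f"{style}: keywords={keyword_names(snippet)}")
```

In an actual test file each snippet would become its own test case, so that a failure points at one specific argument style.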
Revised and added!
    obj.run(pytorch_code, ["result"])


def test_case_18():
Can the keyword be shape? The unit tests for torch.Tensor.view need to be reorganized and merged: increase the variety of cases rather than the number of cases, and remove duplicates.
Right, the keyword cannot be shape. Example code:
import torch
x = torch.arange(4)
x.view(shape=(2,2))
# Error message:
# TypeError: view() received an invalid combination of arguments - got (shape=tuple, ), but expected one of:
# * (tuple of ints size)
# didn't match because some of the keywords were incorrect: shape
# * (torch.dtype dtype)
# didn't match because some of the keywords were incorrect: shape
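The unit tests use several dtypes because, when torch.Tensor.view reinterprets a tensor as another dtype, the result changes whenever the new dtype's element size is larger or smaller than that of self.dtype, as the following code shows:
import torch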
x = torch.arange(4)
print(x.view(torch.bool))
# tensor([False, False, False, False, False, False, False, False, True, False,
# False, False, False, False, False, False, True, False, False, False,
# False, False, False, False, True, False, False, False, False, False,
# False, False])
print(x.view(torch.half))
# tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 5.9605e-08, 0.0000e+00,
# 0.0000e+00, 0.0000e+00, 1.1921e-07, 0.0000e+00, 0.0000e+00, 0.0000e+00,
# 1.7881e-07, 0.0000e+00, 0.0000e+00, 0.0000e+00], dtype=torch.float16)
print(x.view(torch.cfloat))
# tensor([0.0000e+00+0.j, 1.4013e-45+0.j, 2.8026e-45+0.j, 4.2039e-45+0.j])
print(x.view(torch.uint8))
# tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0,
# 3, 0, 0, 0, 0, 0, 0, 0], dtype=torch.uint8)
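The element counts in the outputs above follow from reinterpreting the same 32 bytes at a different element width. A minimal stdlib sketch of that idea, assuming little-endian int64 storage (no torch is involved, and the variable names are illustrative):

```python
import struct

# Four int64 values, packed little-endian (an assumption about the platform).
raw = struct.pack("<4q", 0, 1, 2, 3)

as_uint8 = struct.unpack("<32B", raw)   # 1-byte elements -> 32 values
as_half = struct.unpack("<16e", raw)    # 2-byte floats   -> 16 values
as_float32 = struct.unpack("<8f", raw)  # 4-byte floats   -> 8 values (4 complex pairs)

print(len(raw), len(as_uint8), len(as_half), len(as_float32))
print(as_uint8[8], as_half[4], as_float32[2])
```

The integer 1 shows up as the byte 1 at offset 8, as the smallest float16 subnormal (2**-24), and as the smallest float32 subnormal (2**-149), which matches the 5.9605e-08 and 1.4013e-45 values printed by torch.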
OK, dtype does need several variants. There is one more keyword usage: x.view(size=[2, 2]).
OK, I'll add that test case in the next commit.
paconvert/api_matcher.py
Outdated
            return "unchange"

        if args:
            if len(args) > 1 and isinstance(args[0], (ast.Tuple, ast.List)):
The condition here is not quite right; it should be:
if len(args) == 1:
    if isinstance(args[0], (ast.Tuple, ast.List)):
        return "unchange"
    if isinstance(args[0], (ast.Constant)) and isinstance(args[0].value, str):
        return "unchange"
These branches correspond to the list, tuple, and dtype usages; all three can simply be left as they are.
Yes, the logic was off; it has been fixed. In addition …
paconvert/api_matcher.py
Outdated
            if isinstance(args[0], (ast.Constant)) and isinstance(
                args[0].value, str
            ):
                return "unchange"
Looking at this again, there is one more possible usage: passing a single variable (which may be a list, tuple, torch.dtype, or constant). A further return is needed below the two if statements:
if len(args) == 1:
    if isinstance(args[0], (ast.Tuple, ast.List)):
        return "unchange"
    if isinstance(args[0], (ast.Constant)) and isinstance(
        args[0].value, str
    ):
        return "unchange"
    self.write_aux_code()
    return "unchange"
else:
    self.write_aux_code()
    return "unchange"
After merging the duplicated branches, this becomes:
if len(args) == 1:
    if isinstance(args[0], (ast.Tuple, ast.List)):
        return "unchange"
    if isinstance(args[0], (ast.Constant)) and isinstance(
        args[0].value, str
    ):
        return "unchange"
self.write_aux_code()
return "unchange"
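The three outcomes discussed in this thread can be sketched as a standalone function. This is only an approximation of the reviewed logic: classify_view_call and the "unchange+aux" label are hypothetical, and the real matcher receives func, args, and kwargs rather than a source string.

```python
import ast


def classify_view_call(call_source):
    """Hypothetical helper: classify one `x.view(...)` call expression."""
    call = ast.parse(call_source, mode="eval").body
    if not isinstance(call, ast.Call):
        raise ValueError("expected a call expression")

    args, kwargs = call.args, call.keywords
    if not args and not kwargs:
        # No arguments at all: treated as a misidentified call.
        return "misidentify"

    if not kwargs and len(args) == 1:
        arg = args[0]
        # Literal list/tuple shape, e.g. x.view((2, 2)).
        if isinstance(arg, (ast.Tuple, ast.List)):
            return "unchange"
        # dtype already rewritten to a string, e.g. x.view("float32").
        if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
            return "unchange"

    # Several ints, a variable, or a keyword: keep the call, add the helper.
    return "unchange+aux"


for src in ["x.view((2, 2))", "x.view(2, 2)", "x.view(shape)", "x.view()"]:
    print(src, "->", classify_view_call(src))
```

Returning a label instead of calling write_aux_code keeps the sketch free of side effects, which makes each branch easy to check in isolation.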
OK. I found that torch.view() does not accept a str input, so the str check can be dropped. All unit tests still pass.
import torch
x = torch.rand(4)
x.view('uint8')
# TypeError: view() received an invalid combination of arguments - got (str), but expected one of:
# * (torch.dtype dtype)
# didn't match because some of the arguments have invalid types: (str)
# * (tuple of SymInts size)
# didn't match because some of the arguments have invalid types: (str)
The final version is:
def get_paddle_class_nodes(self, func, args, kwargs):
    if kwargs:
        if len(kwargs) == 1:
            self.write_aux_code()
            return "unchange"
    if args:
        if len(args) == 1 and isinstance(args[0], (ast.Tuple, ast.List)):
            return "unchange"
        else:
            self.write_aux_code()
            return "unchange"
    return "misidentify"
Is this approach acceptable?
            return "unchange"

        if args:
            if len(args) == 1 and isinstance(args[0], (ast.Tuple, ast.List)):
With this change you have dropped one case:
if isinstance(args[0], (ast.Constant)) and isinstance(args[0].value, str):
That is because the input to torch.view is never a str, so I removed that case. The details are in the comment above.
During an actual conversion that branch is still reached: for torch.view(x, torch.float32), the inner torch.float32 is converted to "float32" first, and only then is the outer torch.view converted.
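The ordering described here can be reproduced with a toy rewriter. This is a minimal sketch assuming a children-first traversal as the reviewer describes; InnerFirstRewriter and its single rewrite rule are hypothetical and are not PaConvert's implementation.

```python
import ast


class InnerFirstRewriter(ast.NodeTransformer):
    """Toy rewriter: arguments are converted before the call containing them."""

    def __init__(self):
        self.seen_arg_types = []

    def visit_Attribute(self, node):
        # Illustrative rule only: torch.float32 -> "float32".
        if (
            isinstance(node.value, ast.Name)
            and node.value.id == "torch"
            and node.attr == "float32"
        ):
            return ast.copy_location(ast.Constant(value="float32"), node)
        return self.generic_visit(node)

    def visit_Call(self, node):
        # Rewrite children first, then record what the outer call sees.
        self.generic_visit(node)
        self.seen_arg_types.append([type(a).__name__ for a in node.args])
        return node


tree = ast.parse("x.view(torch.float32)", mode="eval")
rewriter = InnerFirstRewriter()
new_tree = rewriter.visit(tree)
print(rewriter.seen_arg_types)
print(ast.unparse(new_tree))
```

By the time the outer call is handled, its argument is already an ast.Constant holding a str, which is why the str branch is still needed even though user code never passes a string to view.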
OK, fixed!
LGTM. This task still has some other APIs left. You are also welcome to take part in the other 5 tasks; whichever PR is merged first is the one that counts.
PR Docs
PR APIs