-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.42】 为Paddle代码转换工具新增API转换规则 (第1组 编号1-20) -part #306
Changes from all commits
c70c56a
278dabb
18c4fcf
c2ff651
27f95ea
2b34c7e
ac77673
d012be7
f5437e4
fd743e8
b9af256
bba66f9
f0f795a
9a13e34
bfc41d6
82fe9ed
1d49bcf
45a8bf8
4e48493
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -356,7 +356,20 @@ | |
"Matcher": "GenericMatcher", | ||
"paddle_api": "paddle.Tensor.nonzero" | ||
}, | ||
"torch.Tensor.as_strided": {}, | ||
"torch.Tensor.as_strided": { | ||
"Matcher": "GenericMatcher", | ||
"paddle_api": "paddle.Tensor.as_strided", | ||
"min_input_args": 2, | ||
"args_list": [ | ||
"size", | ||
"stride", | ||
"storage_offset" | ||
], | ||
"kwargs_change": { | ||
"size": "shape", | ||
"storage_offset": "offset" | ||
} | ||
}, | ||
"torch.Tensor.as_subclass": {}, | ||
"torch.Tensor.backward": { | ||
"Matcher": "GenericMatcher", | ||
|
@@ -730,8 +743,26 @@ | |
"memory_format" | ||
] | ||
}, | ||
"torch.Tensor.cummax": {}, | ||
"torch.Tensor.cummin": {}, | ||
"torch.Tensor.cummax": { | ||
"Matcher": "DoubleAssignMatcher", | ||
"paddle_api": "paddle.Tensor.cummax", | ||
"args_list": [ | ||
"dim" | ||
], | ||
"kwargs_change": { | ||
"dim": "axis" | ||
} | ||
}, | ||
"torch.Tensor.cummin": { | ||
"Matcher": "DoubleAssignMatcher", | ||
"paddle_api": "paddle.Tensor.cummin", | ||
"args_list": [ | ||
"dim" | ||
], | ||
"kwargs_change": { | ||
"dim": "axis" | ||
} | ||
}, | ||
"torch.Tensor.cumprod": { | ||
"Matcher": "UnchangeMatcher", | ||
"min_input_args": 1, | ||
|
@@ -1309,7 +1340,11 @@ | |
] | ||
}, | ||
"torch.Tensor.hypot_": {}, | ||
"torch.Tensor.i0": {}, | ||
"torch.Tensor.i0": { | ||
"Matcher": "GenericMatcher", | ||
"paddle_api": "paddle.Tensor.i0", | ||
"min_input_args": 0 | ||
}, | ||
"torch.Tensor.i0_": { | ||
"Matcher": "GenericMatcher", | ||
"paddle_api": "paddle.Tensor.i0_", | ||
|
@@ -1355,7 +1390,19 @@ | |
}, | ||
"torch.Tensor.index_fill": {}, | ||
"torch.Tensor.index_fill_": {}, | ||
"torch.Tensor.index_put": {}, | ||
"torch.Tensor.index_put": { | ||
"Matcher": "GenericMatcher", | ||
"paddle_api": "paddle.Tensor.index_put", | ||
"min_input_args": 2, | ||
"args_list": [ | ||
"indices", | ||
"values", | ||
"accumulate" | ||
], | ||
"kwargs_change": { | ||
"values": "value" | ||
} | ||
}, | ||
"torch.Tensor.index_put_": { | ||
"Matcher": "GenericMatcher", | ||
"paddle_api": "paddle.Tensor.index_put_", | ||
|
@@ -3033,7 +3080,19 @@ | |
"sizes": "shape" | ||
} | ||
}, | ||
"torch.Tensor.unfold": {}, | ||
"torch.Tensor.unfold": { | ||
"Matcher": "GenericMatcher", | ||
"paddle_api": "paddle.Tensor.unfold", | ||
"min_input_args": 2, | ||
"args_list": [ | ||
"dimension", | ||
"size", | ||
"step" | ||
], | ||
"kwargs_change": { | ||
"dimension": "axis" | ||
} | ||
}, | ||
"torch.Tensor.uniform_": { | ||
"Matcher": "GenericMatcher", | ||
"paddle_api": "paddle.Tensor.uniform_", | ||
|
@@ -3129,8 +3188,17 @@ | |
"other": "y" | ||
} | ||
}, | ||
"torch.Tensor.view": {}, | ||
"torch.Tensor.view_as": {}, | ||
"torch.Tensor.view": { | ||
"Matcher": "TensorViewMatcher" | ||
}, | ||
"torch.Tensor.view_as": { | ||
"Matcher": "GenericMatcher", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个没有参数吗?args_list需要写 |
||
"paddle_api": "paddle.Tensor.view_as", | ||
"min_input_args": 1, | ||
"args_list": [ | ||
"other" | ||
] | ||
}, | ||
"torch.Tensor.vsplit": {}, | ||
"torch.Tensor.where": { | ||
"Matcher": "GenericMatcher", | ||
|
@@ -3571,6 +3639,22 @@ | |
"input": "x" | ||
} | ||
}, | ||
"torch.as_strided": { | ||
"Matcher": "GenericMatcher", | ||
"paddle_api": "paddle.as_strided", | ||
"min_input_args": 2, | ||
"args_list": [ | ||
"input", | ||
"size", | ||
"stride", | ||
"storage_offset" | ||
], | ||
"kwargs_change": { | ||
"input": "x", | ||
"size": "shape", | ||
"storage_offset": "offset" | ||
} | ||
}, | ||
"torch.as_tensor": { | ||
"Matcher": "GenericMatcher", | ||
"paddle_api": "paddle.to_tensor", | ||
|
@@ -4099,6 +4183,19 @@ | |
"input": "x" | ||
} | ||
}, | ||
"torch.clamp_max": { | ||
"Matcher": "GenericMatcher", | ||
"paddle_api": "paddle.clip", | ||
"args_list": [ | ||
"input", | ||
"max", | ||
"*", | ||
"out" | ||
], | ||
"kwargs_change": { | ||
"input": "x" | ||
} | ||
}, | ||
"torch.clamp_min": { | ||
"Matcher": "GenericMatcher", | ||
"paddle_api": "paddle.clip", | ||
|
@@ -4536,6 +4633,20 @@ | |
"device" | ||
] | ||
}, | ||
"torch.cummax": { | ||
"Matcher": "DoubleAssignMatcher", | ||
"paddle_api": "paddle.cummax", | ||
"args_list": [ | ||
"input", | ||
"dim", | ||
"*", | ||
"out" | ||
], | ||
"kwargs_change": { | ||
"input": "x", | ||
"dim": "axis" | ||
} | ||
}, | ||
"torch.cummin": { | ||
"Matcher": "DoubleAssignMatcher", | ||
"paddle_api": "paddle.cummin", | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2935,6 +2935,42 @@ def generate_code(self, kwargs): | |
return code | ||
|
||
|
||
class TensorViewMatcher(BaseMatcher): | ||
def generate_aux_code(self): | ||
CODE_TEMPLATE = textwrap.dedent( | ||
""" | ||
def view(self, *args, **kwargs): | ||
if args: | ||
if len(args)==1 and isinstance(args[0], (tuple, list, str)): | ||
return paddle.view(self, args[0]) | ||
else: | ||
return paddle.view(self, list(args)) | ||
elif kwargs: | ||
key = [k for k in kwargs.keys()] | ||
return paddle.view(self, shape_or_dtype = kwargs[key[0]]) | ||
|
||
setattr(paddle.Tensor, 'view', view) | ||
""" | ||
) | ||
return CODE_TEMPLATE | ||
|
||
def get_paddle_class_nodes(self, func, args, kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 辅助函数,以及判断什么情况下是 误识别,什么情况下 保持不变即可,什么情况下 保持不变+辅助函数,这里还需要细致梳理一遍代码 |
||
if kwargs: | ||
if len(kwargs) == 1: | ||
self.write_aux_code() | ||
return "unchange" | ||
|
||
if args: | ||
if len(args) == 1 and isinstance(args[0], (ast.Tuple, ast.List)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 你这样改完后,就漏掉了一种情况了: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这是因为 torch.view 的输入不会是str类型,所以我就把这个情况给移除了。具体说明在上面的comment中。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 实际转换时代码逻辑还是会执行到, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的,已修改! |
||
return "unchange" | ||
if isinstance(args[0], (ast.Constant)) and isinstance(args[0].value, str): | ||
return "unchange" | ||
self.write_aux_code() | ||
return "unchange" | ||
|
||
return "misidentify" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个尽量不要模仿其他API,而是分析这个API的自身参数特性,如果输入了 指定关键字 ,例如 |
||
|
||
|
||
class TensorReshapeMatcher(BaseMatcher): | ||
def generate_aux_code(self): | ||
CODE_TEMPLATE = textwrap.dedent( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
import textwrap | ||
|
||
from apibase import APIBase | ||
|
||
obj = APIBase("torch.Tensor.as_strided") | ||
|
||
|
||
def test_case_1(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
x = torch.tensor([[ 0.0335, 0.1830, -0.1269], | ||
[ 0.1897, -0.1422, -0.4940], | ||
[-0.7674, -0.0134, -0.3733]]) | ||
results = x.as_strided((2, 2), (1, 2)) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["results"]) | ||
|
||
|
||
def test_case_2(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
x = torch.tensor([[ 0.0335, 0.1830, -0.1269], | ||
[ 0.1897, -0.1422, -0.4940], | ||
[-0.7674, -0.0134, -0.3733]]) | ||
results = x.as_strided((2, 2), (1, 2), 0) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["results"]) | ||
|
||
|
||
def test_case_3(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
x = torch.tensor([[ 0.0335, 0.1830, -0.1269], | ||
[ 0.1897, -0.1422, -0.4940], | ||
[-0.7674, -0.0134, -0.3733]]) | ||
size = (2, 2) | ||
stride = (1, 2) | ||
storage_offset = 0 | ||
results = x.as_strided(size, stride, storage_offset) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["results"]) | ||
|
||
|
||
def test_case_4(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
x = torch.tensor([[ 0.0335, 0.1830, -0.1269], | ||
[ 0.1897, -0.1422, -0.4940], | ||
[-0.7674, -0.0134, -0.3733]]) | ||
size = (2, 2) | ||
stride = (1, 2) | ||
results = x.as_strided(size, stride, 0) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["results"]) | ||
|
||
|
||
def test_case_5(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
x = torch.tensor([[ 0.0335, 0.1830, -0.1269], | ||
[ 0.1897, -0.1422, -0.4940], | ||
[-0.7674, -0.0134, -0.3733]]) | ||
size = (2, 2) | ||
stride = (1, 2) | ||
results = x.as_strided(size = (2,2), stride = (2,2), storage_offset = 0) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["results"]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import textwrap | ||
|
||
from apibase import APIBase | ||
|
||
obj = APIBase("torch.Tensor.cummax") | ||
|
||
|
||
def test_case_1(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
x = torch.tensor([[1.0, 1.0, 1.0], | ||
[2.0, 2.0, 2.0], | ||
[3.0, 3.0, 3.0]]) | ||
result = x.cummax(0) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_2(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
x = torch.tensor([[1.0, 1.0, 1.0], | ||
[2.0, 2.0, 2.0], | ||
[3.0, 3.0, 3.0]]) | ||
result = x.cummax(1) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_3(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
x = torch.tensor([[1.0, 1.0, 1.0], | ||
[2.0, 2.0, 2.0], | ||
[3.0, 3.0, 3.0]]) | ||
result = x.cummax(dim=1) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_4(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
x = torch.tensor([[1.0, 1.0, 1.0], | ||
[2.0, 2.0, 2.0], | ||
[3.0, 3.0, 3.0]]) | ||
result = x.cummax(dim=0) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle有
paddle.Tensor.view_as