-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.42】 为Paddle代码转换工具新增API转换规则 (第1组 编号1-20) -part #306
Changes from 1 commit
c70c56a
278dabb
18c4fcf
c2ff651
27f95ea
2b34c7e
ac77673
d012be7
f5437e4
fd743e8
b9af256
bba66f9
f0f795a
9a13e34
bfc41d6
82fe9ed
1d49bcf
45a8bf8
4e48493
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2781,6 +2781,55 @@ def generate_code(self, kwargs): | |
return code | ||
|
||
|
||
class TensorViewMatcher(BaseMatcher): | ||
def generate_aux_code(self): | ||
CODE_TEMPLATE = textwrap.dedent( | ||
""" | ||
def view(self, *args, **kwargs): | ||
if args: | ||
if len(args)==1 and isinstance(args[0], (tuple, list, str)): | ||
return paddle.view(self, args[0]) | ||
else: | ||
return paddle.view(self, list(args)) | ||
elif kwargs: | ||
assert 'shape_or_dtype' in kwargs | ||
return paddle.view(self, shape=kwargs['shape_or_dtype']) | ||
|
||
setattr(paddle.Tensor, 'view', view) | ||
""" | ||
) | ||
return CODE_TEMPLATE | ||
|
||
def get_paddle_class_nodes(self, func, args, kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 辅助函数,以及判断什么情况下是 误识别,什么情况下 保持不变即可,什么情况下 保持不变+辅助函数,这里还需要细致梳理一遍代码 |
||
if kwargs: | ||
if len(kwargs) == 1 and "shape_or_dtype" in kwargs: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. torch的代码,不可能有"shape_or_dtype" 吧,这个是转写torch代码,传入的func、args、kwargs都是torch的内容,而且需要先self.parse_func处理func |
||
return "unchange" | ||
else: | ||
return "misidentify" | ||
|
||
if args: | ||
if len(args) > 1 and isinstance(args[0], (ast.Tuple, ast.List)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里判断逻辑不太对,应该是:
对应list/tuple/dtype三种用法,都直接保持原样就行。 |
||
return "unchange" | ||
else: | ||
self.write_aux_code() | ||
return "unchange" | ||
|
||
return "misidentify" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个尽量不要模仿其他API,而是分析这个API的自身参数特性,如果输入了 指定关键字 ,例如 |
||
|
||
|
||
class TensorView_asMatcher(BaseMatcher): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. paddle不是也有paddle.Tensor.view_as吗 |
||
def generate_code(self, kwargs): | ||
|
||
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
{}.view({}.shape) | ||
""" | ||
) | ||
code = API_TEMPLATE.format(self.paddleClass, kwargs["other"]) | ||
|
||
return code | ||
|
||
|
||
class TensorReshapeMatcher(BaseMatcher): | ||
def generate_aux_code(self): | ||
CODE_TEMPLATE = textwrap.dedent( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,60 +16,58 @@ | |
|
||
from apibase import APIBase | ||
|
||
obj = APIBase("torch.Tensor") | ||
obj = APIBase("torch.tensor") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
|
||
def test_case_1(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
result = torch.Tensor(2, 3) | ||
result = torch.tensor([2, 3]) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"], check_value=False) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_2(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
shape = [2, 3] | ||
result = torch.Tensor(*shape) | ||
data = [2, 3] | ||
result = torch.tensor(data) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"], check_value=False) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_3(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
dim1, dim2 = 2, 3 | ||
result = torch.Tensor(dim1, dim2) | ||
data = [2, 3] | ||
result = torch.tensor(data, dtype=torch.float) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"], check_value=False) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_4(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
def fun(x: torch.Tensor): | ||
return x * 2 | ||
|
||
a = torch.Tensor(3, 4) | ||
result = fun(a) | ||
data = [2, 3] | ||
result = torch.tensor(data, dtype=torch.float, device=None) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"], check_value=False) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_5(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
result = torch.Tensor([[3, 4], [5, 8]]) | ||
data = [2, 3] | ||
result = torch.tensor(data, dtype=torch.float, device=None, requires_grad = False) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
@@ -79,8 +77,10 @@ def test_case_6(): | |
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
a = torch.tensor([[3, 4], [5, 8]]) | ||
result = torch.Tensor(a) | ||
data = [2, 3] | ||
result = None | ||
if torch.cuda.is_available(): | ||
result = torch.tensor(data, requires_grad = False, pin_memory=True) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
@@ -90,22 +90,8 @@ def test_case_7(): | |
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
result = torch.Tensor((1, 2, 3)) | ||
data = [2, 3] | ||
result = torch.tensor(data, requires_grad = False, pin_memory=False) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_8(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
result = torch.Tensor() | ||
""" | ||
) | ||
obj.run( | ||
pytorch_code, | ||
["result"], | ||
unsupport=True, | ||
reason="paddle does not support 0-Size Tensor", | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,31 +11,20 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import textwrap | ||
|
||
from apibase import APIBase | ||
|
||
obj = APIBase("torch.Tensor.T") | ||
obj = APIBase("torch.Tensor.t") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
|
||
def test_case_1(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
x = torch.arange(16).reshape(4, 4) | ||
result = x.T | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_2(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
result = torch.arange(16).reshape(4, 4).T | ||
a = torch.Tensor([[1.,2.], [3.,4.]]) | ||
result = a.t() | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle有
paddle.Tensor.view_as