Skip to content

Commit c909236

Browse files
authored
Add codemod for TorchNonPublicAliasVisitor (#36)
* Add codemod for TorchNonPublicAliasVisitor * Format * noqa
1 parent 7cc047d commit c909236

File tree

3 files changed

+74
-2
lines changed

3 files changed

+74
-2
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from torch.utils.data import _utils # will not be removed as it could be used for something besides default_collate
2+
batch = _utils.collate.default_collate(batch)
3+
4+
from torch.utils.data._utils import collate # also will not be removed
5+
batch = collate.default_collate(batch)
6+
7+
from torch.utils.data._utils.collate import default_collate
8+
inputs, labels, video_idx = default_collate(inputs), default_collate(labels), default_collate(video_idx)
9+
10+
from torch.utils.data._utils.collate import default_convert
11+
values = default_convert(values)
12+
13+
import torch
14+
values = torch.utils.data._utils.collate.default_convert(values)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from torch.utils.data import dataloader, _utils # will not be removed as it could be used for something besides default_collate
2+
batch = dataloader.default_collate(batch)
3+
4+
from torch.utils.data._utils import collate # also will not be removed
5+
batch = dataloader.default_collate(batch)
6+
7+
from torch.utils.data.dataloader import default_collate
8+
inputs, labels, video_idx = default_collate(inputs), default_collate(labels), default_collate(video_idx)
9+
10+
from torch.utils.data.dataloader import default_convert
11+
values = default_convert(values)
12+
13+
import torch
14+
values = torch.utils.data.dataloader.default_convert(values)

torchfix/visitors/nonpublic/__init__.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from os.path import commonprefix
12
from typing import Sequence
23

34
import libcst as cst
5+
from libcst.codemod.visitors import ImportItem
6+
47
from ...common import TorchVisitor
58

69

@@ -32,7 +35,32 @@ def visit_Call(self, node):
3235
public_name = self.ALIASES[qualified_name]
3336
error_code = self.ERROR_CODE[0]
3437
message = f"Use of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501
35-
self.add_violation(node, error_code=error_code, message=message)
38+
39+
call_name = cst.helpers.get_full_name_for_node(node)
40+
replacement = None
41+
if not public_name.endswith(call_name):
42+
# We need to change the call name as it's not in the public name.
43+
# Get the new call name on the same hierarchical level.
44+
new_call_name = public_name.removeprefix(
45+
commonprefix([qualified_name.removesuffix(call_name), public_name])
46+
)
47+
new_module_name = public_name.removesuffix(new_call_name).removesuffix(
48+
"."
49+
)
50+
if new_module_name:
51+
self.needed_imports.add(
52+
ImportItem(
53+
module_name=new_module_name,
54+
obj_name=new_call_name.split(".")[0],
55+
)
56+
)
57+
replacement = node.with_changes(
58+
func=cst.parse_expression(new_call_name)
59+
)
60+
61+
self.add_violation(
62+
node, error_code=error_code, message=message, replacement=replacement
63+
)
3664

3765
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
3866
if node.module is None:
@@ -48,4 +76,20 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
4876
public_name = self.ALIASES[qualified_name]
4977
error_code = self.ERROR_CODE[1]
5078
message = f"Import of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501
51-
self.add_violation(node, error_code=error_code, message=message)
79+
80+
new_module = ".".join(public_name.split(".")[:-1])
81+
new_name = public_name.split(".")[-1]
82+
# Replace only if the import statement has no other names
83+
if len(node.names) == 1:
84+
replacement = cst.ImportFrom(
85+
module=cst.parse_expression(new_module), # type: ignore[arg-type] # noqa: E501
86+
names=[cst.ImportAlias(name=cst.Name(new_name))],
87+
)
88+
else:
89+
replacement = None
90+
self.add_violation(
91+
node,
92+
error_code=error_code,
93+
message=message,
94+
replacement=replacement,
95+
)

0 commit comments

Comments
 (0)