Skip to content

Add codemod for TorchNonPublicAliasVisitor #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tests/fixtures/nonpublic/codemod/default_collate_convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from torch.utils.data import _utils # will not be removed as it could be used for something besides default_collate
batch = _utils.collate.default_collate(batch)

from torch.utils.data._utils import collate # also will not be removed
batch = collate.default_collate(batch)

from torch.utils.data._utils.collate import default_collate
inputs, labels, video_idx = default_collate(inputs), default_collate(labels), default_collate(video_idx)

from torch.utils.data._utils.collate import default_convert
values = default_convert(values)

import torch
values = torch.utils.data._utils.collate.default_convert(values)
14 changes: 14 additions & 0 deletions tests/fixtures/nonpublic/codemod/default_collate_convert.py.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from torch.utils.data import dataloader, _utils # will not be removed as it could be used for something besides default_collate
batch = dataloader.default_collate(batch)

from torch.utils.data._utils import collate # also will not be removed
batch = dataloader.default_collate(batch)

from torch.utils.data.dataloader import default_collate
inputs, labels, video_idx = default_collate(inputs), default_collate(labels), default_collate(video_idx)

from torch.utils.data.dataloader import default_convert
values = default_convert(values)

import torch
values = torch.utils.data.dataloader.default_convert(values)
48 changes: 46 additions & 2 deletions torchfix/visitors/nonpublic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from os.path import commonprefix
from typing import Sequence

import libcst as cst
from libcst.codemod.visitors import ImportItem

from ...common import TorchVisitor


Expand Down Expand Up @@ -32,7 +35,32 @@ def visit_Call(self, node):
public_name = self.ALIASES[qualified_name]
error_code = self.ERROR_CODE[0]
message = f"Use of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501
self.add_violation(node, error_code=error_code, message=message)

call_name = cst.helpers.get_full_name_for_node(node)
replacement = None
if not public_name.endswith(call_name):
# We need to change the call name as it's not in the public name.
# Get the new call name on the same hierarchical level.
new_call_name = public_name.removeprefix(
commonprefix([qualified_name.removesuffix(call_name), public_name])
)
new_module_name = public_name.removesuffix(new_call_name).removesuffix(
"."
)
if new_module_name:
self.needed_imports.add(
ImportItem(
module_name=new_module_name,
obj_name=new_call_name.split(".")[0],
)
)
replacement = node.with_changes(
func=cst.parse_expression(new_call_name)
)

self.add_violation(
node, error_code=error_code, message=message, replacement=replacement
)

def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
if node.module is None:
Expand All @@ -48,4 +76,20 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
public_name = self.ALIASES[qualified_name]
error_code = self.ERROR_CODE[1]
message = f"Import of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501
self.add_violation(node, error_code=error_code, message=message)

new_module = ".".join(public_name.split(".")[:-1])
new_name = public_name.split(".")[-1]
# Replace only if the import statement has no other names
if len(node.names) == 1:
replacement = cst.ImportFrom(
module=cst.parse_expression(new_module), # type: ignore[arg-type] # noqa: E501
names=[cst.ImportAlias(name=cst.Name(new_name))],
)
else:
replacement = None
self.add_violation(
node,
error_code=error_code,
message=message,
replacement=replacement,
)