Skip to content

Commit 8846f5c

Browse files
authored
[Issue 7] Update import torchvision.models as models (#26)
* [Issue 7] Update import torchvision.models as models * Move torchvision.models visitor to vision dir * Move torchvision.models visitor to vision dir
1 parent 35f2488 commit 8846f5c

File tree

5 files changed

+49
-0
lines changed

5 files changed

+49
-0
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import torchvision.models as models
2+
import torchvision.models as cnn
3+
from torchvision.models import resnet50, resnet101
4+
import torchvision.models
5+
from torchvision.models import *
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
1:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'.

torchfix/torchfix.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .visitors.vision import (
2020
TorchVisionDeprecatedPretrainedVisitor,
2121
TorchVisionDeprecatedToTensorVisitor,
22+
TorchVisionModelsImportVisitor,
2223
)
2324
from .visitors.security import TorchUnsafeLoadVisitor
2425

@@ -35,6 +36,7 @@
3536
TorchSynchronizedDataLoaderVisitor,
3637
TorchVisionDeprecatedPretrainedVisitor,
3738
TorchVisionDeprecatedToTensorVisitor,
39+
TorchVisionModelsImportVisitor,
3840
TorchUnsafeLoadVisitor,
3941
TorchReentrantCheckpointVisitor,
4042
]

torchfix/visitors/vision/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .pretrained import TorchVisionDeprecatedPretrainedVisitor # noqa: F401
22
from .to_tensor import TorchVisionDeprecatedToTensorVisitor # noqa: F401
3+
from .models_import import TorchVisionModelsImportVisitor # noqa: F401
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import libcst as cst
2+
3+
from ...common import LintViolation, TorchVisitor
4+
5+
6+
class TorchVisionModelsImportVisitor(TorchVisitor):
7+
ERROR_CODE = "TOR203"
8+
9+
def visit_Import(self, node: cst.Import) -> None:
10+
for imported_item in node.names:
11+
if isinstance(imported_item.name, cst.Attribute):
12+
if (
13+
isinstance(imported_item.name.value, cst.Name)
14+
and imported_item.name.value.value == "torchvision"
15+
and isinstance(imported_item.name.attr, cst.Name)
16+
and imported_item.name.attr.value == "models"
17+
and imported_item.asname is not None
18+
and isinstance(imported_item.asname.name, cst.Name)
19+
and imported_item.asname.name.value == "models"
20+
):
21+
position = self.get_metadata(
22+
cst.metadata.WhitespaceInclusivePositionProvider, node
23+
)
24+
replacement = cst.ImportFrom(
25+
module=cst.Name("torchvision"),
26+
names=[cst.ImportAlias(name=cst.Name("models"))],
27+
)
28+
self.violations.append(
29+
LintViolation(
30+
error_code=self.ERROR_CODE,
31+
message=(
32+
"Consider replacing 'import torchvision.models as"
33+
" models' with 'from torchvision import models'."
34+
),
35+
line=position.start.line,
36+
column=position.start.column,
37+
node=node,
38+
replacement=replacement
39+
)
40+
)

0 commit comments

Comments
 (0)