Skip to content

Commit f219838

Browse files
authored
Update --select arg to accept specific rules (#16)
1 parent 92136fe commit f219838

File tree

4 files changed

+139
-26
lines changed

4 files changed

+139
-26
lines changed

tests/test_torchfix.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
TorchChecker,
44
TorchCodemod,
55
TorchCodemodConfig,
6+
DISABLED_BY_DEFAULT,
7+
expand_error_codes,
68
GET_ALL_VISITORS,
9+
GET_ALL_ERROR_CODES,
10+
process_error_code_str,
711
)
812
import logging
913
import libcst.codemod as codemod
@@ -20,7 +24,7 @@ def _checker_results(s):
2024
def _codemod_results(source_path):
2125
with open(source_path) as source:
2226
code = source.read()
23-
config = TorchCodemodConfig(select="ALL")
27+
config = TorchCodemodConfig(select=list(GET_ALL_ERROR_CODES()))
2428
context = TorchCodemod(codemod.CodemodContext(filename=source_path), config)
2529
new_module = codemod.transform_module(context, code)
2630
return new_module.code
@@ -60,3 +64,17 @@ def test_errorcodes_distinct():
6064
for e in error_code if isinstance(error_code, list) else [error_code]:
6165
assert e not in seen
6266
seen.add(e)
67+
68+
69+
def test_parse_error_code_str():
70+
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
71+
cases = [
72+
("ALL", GET_ALL_ERROR_CODES()),
73+
("ALL,TOR102", GET_ALL_ERROR_CODES()),
74+
("TOR102", {"TOR102"}),
75+
("TOR102,TOR101", {"TOR102", "TOR101"}),
76+
("TOR1,TOR102", {"TOR102", "TOR101"}),
77+
(None, GET_ALL_ERROR_CODES() - exclude_set),
78+
]
79+
for case, expected in cases:
80+
assert expected == process_error_code_str(case)

torchfix/__main__.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@
66
import sys
77
import io
88

9-
from .torchfix import TorchCodemod, TorchCodemodConfig, __version__ as TorchFixVersion
9+
from .torchfix import (
10+
TorchCodemod,
11+
TorchCodemodConfig,
12+
__version__ as TorchFixVersion,
13+
DISABLED_BY_DEFAULT,
14+
GET_ALL_ERROR_CODES,
15+
process_error_code_str,
16+
)
1017
from .common import CYAN, ENDC
1118

1219

@@ -55,10 +62,11 @@ def main() -> None:
5562
)
5663
parser.add_argument(
5764
"--select",
58-
help="ALL to enable rules disabled by default",
59-
choices=[
60-
"ALL",
61-
],
65+
help=f"Comma-separated list of rules to enable or 'ALL' to enable all rules. "
66+
f"Available rules: {', '.join(list(GET_ALL_ERROR_CODES()))}. "
67+
f"Defaults to all except for {', '.join(DISABLED_BY_DEFAULT)}.",
68+
type=str,
69+
default=None,
6270
)
6371
parser.add_argument(
6472
"--version",
@@ -94,7 +102,7 @@ def main() -> None:
94102
break
95103

96104
config = TorchCodemodConfig()
97-
config.select = args.select
105+
config.select = list(process_error_code_str(args.select))
98106
command_instance = TorchCodemod(codemod.CodemodContext(), config)
99107
DIFF_CONTEXT = 5
100108
try:

torchfix/torchfix.py

Lines changed: 105 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import dataclass
2+
import functools
23
from pathlib import Path
3-
from typing import Optional
4+
from typing import Optional, List
45
import libcst as cst
56
import libcst.codemod as codemod
67

@@ -25,17 +26,100 @@
2526

2627
DISABLED_BY_DEFAULT = ["TOR3", "TOR4"]
2728

29+
ALL_VISITOR_CLS = [
30+
TorchDeprecatedSymbolsVisitor,
31+
TorchRequireGradVisitor,
32+
TorchSynchronizedDataLoaderVisitor,
33+
TorchVisionDeprecatedPretrainedVisitor,
34+
TorchVisionDeprecatedToTensorVisitor,
35+
TorchUnsafeLoadVisitor,
36+
TorchReentrantCheckpointVisitor,
37+
]
38+
39+
40+
@functools.cache
41+
def GET_ALL_ERROR_CODES():
42+
codes = set()
43+
for cls in ALL_VISITOR_CLS:
44+
if isinstance(cls.ERROR_CODE, list):
45+
codes |= set(cls.ERROR_CODE)
46+
else:
47+
codes.add(cls.ERROR_CODE)
48+
return codes
49+
50+
51+
@functools.cache
52+
def expand_error_codes(codes):
53+
out_codes = set()
54+
for c_a in codes:
55+
for c_b in GET_ALL_ERROR_CODES():
56+
if c_b.startswith(c_a):
57+
out_codes.add(c_b)
58+
return out_codes
59+
60+
61+
def construct_visitor(cls):
62+
if cls is TorchDeprecatedSymbolsVisitor:
63+
return cls(DEPRECATED_CONFIG_PATH)
64+
else:
65+
return cls()
66+
2867

2968
def GET_ALL_VISITORS():
30-
return [
31-
TorchDeprecatedSymbolsVisitor(DEPRECATED_CONFIG_PATH),
32-
TorchRequireGradVisitor(),
33-
TorchSynchronizedDataLoaderVisitor(),
34-
TorchVisionDeprecatedPretrainedVisitor(),
35-
TorchVisionDeprecatedToTensorVisitor(),
36-
TorchUnsafeLoadVisitor(),
37-
TorchReentrantCheckpointVisitor(),
38-
]
69+
out = []
70+
for v in ALL_VISITOR_CLS:
71+
out.append(construct_visitor(v))
72+
return out
73+
74+
75+
def get_visitors_with_error_codes(error_codes):
76+
visitor_classes = set()
77+
for error_code in error_codes:
78+
# Assume the error codes have been expanded so each error code can
79+
# only correspond to one visitor.
80+
found = False
81+
for visitor_cls in ALL_VISITOR_CLS:
82+
if isinstance(visitor_cls.ERROR_CODE, list):
83+
if error_code in visitor_cls.ERROR_CODE:
84+
visitor_classes.add(visitor_cls)
85+
found = True
86+
break
87+
else:
88+
if error_code == visitor_cls.ERROR_CODE:
89+
visitor_classes.add(visitor_cls)
90+
found = True
91+
break
92+
if not found:
93+
raise AssertionError(f"Unknown error code: {error_code}")
94+
out = []
95+
for cls in visitor_classes:
96+
out.append(construct_visitor(cls))
97+
return out
98+
99+
100+
def process_error_code_str(code_str):
101+
# Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001.
102+
# We deduplicate them here.
103+
104+
# Default when --select is not provided.
105+
if code_str is None:
106+
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
107+
return GET_ALL_ERROR_CODES() - exclude_set
108+
109+
raw_codes = [s.strip() for s in code_str.split(",")]
110+
111+
# Validate error codes
112+
for c in raw_codes:
113+
if c == "ALL":
114+
continue
115+
if len(expand_error_codes((c,))) == 0:
116+
raise ValueError(f"Invalid error code: {c}, available error "
117+
f"codes: {list(GET_ALL_ERROR_CODES())}")
118+
119+
if "ALL" in raw_codes:
120+
return GET_ALL_ERROR_CODES()
121+
122+
return expand_error_codes(tuple(raw_codes))
39123

40124

41125
# Flake8 plugin
@@ -78,7 +162,7 @@ def add_options(optmanager):
78162
# Standalone torchfix command
79163
@dataclass
80164
class TorchCodemodConfig:
81-
select: Optional[str] = None
165+
select: Optional[List[str]] = None
82166

83167

84168
class TorchCodemod(codemod.Codemod):
@@ -97,8 +181,10 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:
97181
# in that case we would need to use `wrapped_module.module`
98182
# instead of `module`.
99183
wrapped_module = cst.MetadataWrapper(module, unsafe_skip_copy=True)
184+
if self.config is None or self.config.select is None:
185+
raise AssertionError("Expected self.config.select to be set")
186+
visitors = get_visitors_with_error_codes(self.config.select)
100187

101-
visitors = GET_ALL_VISITORS()
102188
violations = []
103189
needed_imports = []
104190
wrapped_module.visit_batched(visitors)
@@ -110,12 +196,13 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:
110196
replacement_map = {}
111197
assert self.context.filename is not None
112198
for violation in violations:
113-
skip_violation = False
114-
if self.config is None or self.config.select != "ALL":
115-
for disabled_code in DISABLED_BY_DEFAULT:
116-
if violation.error_code.startswith(disabled_code):
117-
skip_violation = True
118-
break
199+
# Still need to skip violations here, since a single visitor can
200+
# correspond to multiple different types of violations.
201+
skip_violation = True
202+
for code in self.config.select:
203+
if violation.error_code.startswith(code):
204+
skip_violation = False
205+
break
119206
if skip_violation:
120207
continue
121208

torchfix/visitors/vision/to_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def visit_ImportFrom(self, node):
4343

4444
def visit_Attribute(self, node):
4545
qualified_names = self.get_metadata(cst.metadata.QualifiedNameProvider, node)
46-
if not len(qualified_names) == 1:
46+
if len(qualified_names) != 1:
4747
return
4848

4949
self._maybe_add_violation(qualified_names.pop().name, node)

0 commit comments

Comments
 (0)