1
1
from dataclasses import dataclass
2
+ import functools
2
3
from pathlib import Path
3
- from typing import Optional
4
+ from typing import Optional , List
4
5
import libcst as cst
5
6
import libcst .codemod as codemod
6
7
25
26
26
27
DISABLED_BY_DEFAULT = ["TOR3" , "TOR4" ]
27
28
29
+ ALL_VISITOR_CLS = [
30
+ TorchDeprecatedSymbolsVisitor ,
31
+ TorchRequireGradVisitor ,
32
+ TorchSynchronizedDataLoaderVisitor ,
33
+ TorchVisionDeprecatedPretrainedVisitor ,
34
+ TorchVisionDeprecatedToTensorVisitor ,
35
+ TorchUnsafeLoadVisitor ,
36
+ TorchReentrantCheckpointVisitor ,
37
+ ]
38
+
39
+
40
+ @functools .cache
41
+ def GET_ALL_ERROR_CODES ():
42
+ codes = set ()
43
+ for cls in ALL_VISITOR_CLS :
44
+ if isinstance (cls .ERROR_CODE , list ):
45
+ codes |= set (cls .ERROR_CODE )
46
+ else :
47
+ codes .add (cls .ERROR_CODE )
48
+ return codes
49
+
50
+
51
+ @functools .cache
52
+ def expand_error_codes (codes ):
53
+ out_codes = set ()
54
+ for c_a in codes :
55
+ for c_b in GET_ALL_ERROR_CODES ():
56
+ if c_b .startswith (c_a ):
57
+ out_codes .add (c_b )
58
+ return out_codes
59
+
60
+
61
+ def construct_visitor (cls ):
62
+ if cls is TorchDeprecatedSymbolsVisitor :
63
+ return cls (DEPRECATED_CONFIG_PATH )
64
+ else :
65
+ return cls ()
66
+
28
67
29
68
def GET_ALL_VISITORS ():
30
- return [
31
- TorchDeprecatedSymbolsVisitor (DEPRECATED_CONFIG_PATH ),
32
- TorchRequireGradVisitor (),
33
- TorchSynchronizedDataLoaderVisitor (),
34
- TorchVisionDeprecatedPretrainedVisitor (),
35
- TorchVisionDeprecatedToTensorVisitor (),
36
- TorchUnsafeLoadVisitor (),
37
- TorchReentrantCheckpointVisitor (),
38
- ]
69
+ out = []
70
+ for v in ALL_VISITOR_CLS :
71
+ out .append (construct_visitor (v ))
72
+ return out
73
+
74
+
75
+ def get_visitors_with_error_codes (error_codes ):
76
+ visitor_classes = set ()
77
+ for error_code in error_codes :
78
+ # Assume the error codes have been expanded so each error code can
79
+ # only correspond to one visitor.
80
+ found = False
81
+ for visitor_cls in ALL_VISITOR_CLS :
82
+ if isinstance (visitor_cls .ERROR_CODE , list ):
83
+ if error_code in visitor_cls .ERROR_CODE :
84
+ visitor_classes .add (visitor_cls )
85
+ found = True
86
+ break
87
+ else :
88
+ if error_code == visitor_cls .ERROR_CODE :
89
+ visitor_classes .add (visitor_cls )
90
+ found = True
91
+ break
92
+ if not found :
93
+ raise AssertionError (f"Unknown error code: { error_code } " )
94
+ out = []
95
+ for cls in visitor_classes :
96
+ out .append (construct_visitor (cls ))
97
+ return out
98
+
99
+
100
+ def process_error_code_str (code_str ):
101
+ # Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001.
102
+ # We deduplicate them here.
103
+
104
+ # Default when --select is not provided.
105
+ if code_str is None :
106
+ exclude_set = expand_error_codes (tuple (DISABLED_BY_DEFAULT ))
107
+ return GET_ALL_ERROR_CODES () - exclude_set
108
+
109
+ raw_codes = [s .strip () for s in code_str .split ("," )]
110
+
111
+ # Validate error codes
112
+ for c in raw_codes :
113
+ if c == "ALL" :
114
+ continue
115
+ if len (expand_error_codes ((c ,))) == 0 :
116
+ raise ValueError (f"Invalid error code: { c } , available error "
117
+ f"codes: { list (GET_ALL_ERROR_CODES ())} " )
118
+
119
+ if "ALL" in raw_codes :
120
+ return GET_ALL_ERROR_CODES ()
121
+
122
+ return expand_error_codes (tuple (raw_codes ))
39
123
40
124
41
125
# Flake8 plugin
@@ -78,7 +162,7 @@ def add_options(optmanager):
78
162
# Standalone torchfix command
79
163
@dataclass
80
164
class TorchCodemodConfig :
81
- select : Optional [str ] = None
165
+ select : Optional [List [ str ] ] = None
82
166
83
167
84
168
class TorchCodemod (codemod .Codemod ):
@@ -97,8 +181,10 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:
97
181
# in that case we would need to use `wrapped_module.module`
98
182
# instead of `module`.
99
183
wrapped_module = cst .MetadataWrapper (module , unsafe_skip_copy = True )
184
+ if self .config is None or self .config .select is None :
185
+ raise AssertionError ("Expected self.config.select to be set" )
186
+ visitors = get_visitors_with_error_codes (self .config .select )
100
187
101
- visitors = GET_ALL_VISITORS ()
102
188
violations = []
103
189
needed_imports = []
104
190
wrapped_module .visit_batched (visitors )
@@ -110,12 +196,13 @@ def transform_module_impl(self, module: cst.Module) -> cst.Module:
110
196
replacement_map = {}
111
197
assert self .context .filename is not None
112
198
for violation in violations :
113
- skip_violation = False
114
- if self .config is None or self .config .select != "ALL" :
115
- for disabled_code in DISABLED_BY_DEFAULT :
116
- if violation .error_code .startswith (disabled_code ):
117
- skip_violation = True
118
- break
199
+ # Still need to skip violations here, since a single visitor can
200
+ # correspond to multiple different types of violations.
201
+ skip_violation = True
202
+ for code in self .config .select :
203
+ if violation .error_code .startswith (code ):
204
+ skip_violation = False
205
+ break
119
206
if skip_violation :
120
207
continue
121
208
0 commit comments