File tree Expand file tree Collapse file tree 5 files changed +58
-1
lines changed
tests/fixtures/performance/checker Expand file tree Collapse file tree 5 files changed +58
-1
lines changed Original file line number Diff line number Diff line change
1
+ import torch
2
+ import torch .nn as nn
3
+
4
+ x = torch .ones ((100 , 100 ))
5
+ model = nn .Sequential ()
6
+ optimizer = torch .optim .Adam (model .parameters ())
7
+
8
+ # This should raise flags
9
+ optimizer .zero_grad (set_to_none = False )
10
+ model .zero_grad (set_to_none = False )
11
+
12
+ # This should not raise flags
13
+ optimizer .zero_grad ()
14
+ model .zero_grad ()
15
+
16
+
Original file line number Diff line number Diff line change
1
+ 9:1 TOR402 Detected gradient set to zero instead of None. Please add 'set_to_none=True' when calling zero_grad().
2
+ 10:1 TOR402 Detected gradient set to zero instead of None. Please add 'set_to_none=True' when calling zero_grad().
Original file line number Diff line number Diff line change 21
21
TorchVisionDeprecatedPretrainedVisitor ,
22
22
TorchVisionDeprecatedToTensorVisitor ,
23
23
TorchVisionSingletonImportVisitor ,
24
+ TorchGradNotSetToNonePatternVisitor ,
24
25
)
25
26
26
27
__version__ = "0.7.0"
43
44
TorchVisionDeprecatedPretrainedVisitor ,
44
45
TorchVisionDeprecatedToTensorVisitor ,
45
46
TorchVisionSingletonImportVisitor ,
47
+ TorchGradNotSetToNonePatternVisitor ,
46
48
]
47
49
48
50
Original file line number Diff line number Diff line change 8
8
TorchRequireGradVisitor ,
9
9
)
10
10
from .nonpublic import TorchNonPublicAliasVisitor
11
- from .performance import TorchSynchronizedDataLoaderVisitor
11
+ from .performance import (
12
+ TorchSynchronizedDataLoaderVisitor ,
13
+ TorchGradNotSetToNonePatternVisitor ,
14
+ )
12
15
from .security import TorchUnsafeLoadVisitor
13
16
from .vision import (
14
17
TorchVisionDeprecatedPretrainedVisitor ,
30
33
"TorchVisionDeprecatedPretrainedVisitor" ,
31
34
"TorchVisionDeprecatedToTensorVisitor" ,
32
35
"TorchVisionSingletonImportVisitor" ,
36
+ "TorchGradNotSetToNonePatternVisitor" ,
33
37
]
Original file line number Diff line number Diff line change @@ -32,3 +32,36 @@ def visit_Call(self, node):
32
32
error_code = self .ERRORS [0 ].error_code ,
33
33
message = self .ERRORS [0 ].message (),
34
34
)
35
+
36
+
37
+ class TorchGradNotSetToNonePatternVisitor (TorchVisitor ):
38
+ """
39
+ Reimplementation of GradNotSetToNonePattern from
40
+ https://github.com/pytorch/pytorch/blob/main/torch/profiler/_pattern_matcher.py
41
+ """
42
+
43
+ ERRORS = [
44
+ TorchError (
45
+ "TOR402" ,
46
+ (
47
+ "Detected gradient set to zero instead of None. "
48
+ "Please add 'set_to_none=True' when calling zero_grad()."
49
+ ),
50
+ )
51
+ ]
52
+
53
+ def visit_Call (self , node ):
54
+ qualified_name = self .get_qualified_name_for_call (node )
55
+
56
+ if qualified_name and qualified_name .endswith ("zero_grad" ):
57
+
58
+ set_to_none_arg = self .get_specific_arg (node , "set_to_none" , 0 )
59
+
60
+ # hasattr check to handle mypy error
61
+ if set_to_none_arg and hasattr (set_to_none_arg .value , "value" ):
62
+ if set_to_none_arg .value .value == "False" :
63
+ self .add_violation (
64
+ node ,
65
+ error_code = self .ERRORS [0 ].error_code ,
66
+ message = self .ERRORS [0 ].message (),
67
+ )
You can’t perform that action at this time.
0 commit comments