Skip to content

Commit 0ee382c

Browse files
authored
Reimplementation of GradNotSetToNonePattern from Torchtidy (#92)
Adding rules to check for `set_to_none` parameter for `zero_grad()`. By setting set_to_none=True, we can gain speedup
1 parent 00954c9 commit 0ee382c

File tree

5 files changed

+58
-1
lines changed

5 files changed

+58
-1
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
x = torch.ones((100, 100))
5+
model = nn.Sequential()
6+
optimizer = torch.optim.Adam(model.parameters())
7+
8+
# This should raise flags
9+
optimizer.zero_grad(set_to_none=False)
10+
model.zero_grad(set_to_none=False)
11+
12+
# This should not raise flags
13+
optimizer.zero_grad()
14+
model.zero_grad()
15+
16+
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
9:1 TOR402 Detected gradient set to zero instead of None. Please add 'set_to_none=True' when calling zero_grad().
2+
10:1 TOR402 Detected gradient set to zero instead of None. Please add 'set_to_none=True' when calling zero_grad().

torchfix/torchfix.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
TorchVisionDeprecatedPretrainedVisitor,
2222
TorchVisionDeprecatedToTensorVisitor,
2323
TorchVisionSingletonImportVisitor,
24+
TorchGradNotSetToNonePatternVisitor,
2425
)
2526

2627
__version__ = "0.7.0"
@@ -43,6 +44,7 @@
4344
TorchVisionDeprecatedPretrainedVisitor,
4445
TorchVisionDeprecatedToTensorVisitor,
4546
TorchVisionSingletonImportVisitor,
47+
TorchGradNotSetToNonePatternVisitor,
4648
]
4749

4850

torchfix/visitors/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
TorchRequireGradVisitor,
99
)
1010
from .nonpublic import TorchNonPublicAliasVisitor
11-
from .performance import TorchSynchronizedDataLoaderVisitor
11+
from .performance import (
12+
TorchSynchronizedDataLoaderVisitor,
13+
TorchGradNotSetToNonePatternVisitor,
14+
)
1215
from .security import TorchUnsafeLoadVisitor
1316
from .vision import (
1417
TorchVisionDeprecatedPretrainedVisitor,
@@ -30,4 +33,5 @@
3033
"TorchVisionDeprecatedPretrainedVisitor",
3134
"TorchVisionDeprecatedToTensorVisitor",
3235
"TorchVisionSingletonImportVisitor",
36+
"TorchGradNotSetToNonePatternVisitor",
3337
]

torchfix/visitors/performance/__init__.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,36 @@ def visit_Call(self, node):
3232
error_code=self.ERRORS[0].error_code,
3333
message=self.ERRORS[0].message(),
3434
)
35+
36+
37+
class TorchGradNotSetToNonePatternVisitor(TorchVisitor):
38+
"""
39+
Reimplementation of GradNotSetToNonePattern from
40+
https://github.com/pytorch/pytorch/blob/main/torch/profiler/_pattern_matcher.py
41+
"""
42+
43+
ERRORS = [
44+
TorchError(
45+
"TOR402",
46+
(
47+
"Detected gradient set to zero instead of None. "
48+
"Please add 'set_to_none=True' when calling zero_grad()."
49+
),
50+
)
51+
]
52+
53+
def visit_Call(self, node):
54+
qualified_name = self.get_qualified_name_for_call(node)
55+
56+
if qualified_name and qualified_name.endswith("zero_grad"):
57+
58+
set_to_none_arg = self.get_specific_arg(node, "set_to_none", 0)
59+
60+
# hasattr check to handle mypy error
61+
if set_to_none_arg and hasattr(set_to_none_arg.value, "value"):
62+
if set_to_none_arg.value.value == "False":
63+
self.add_violation(
64+
node,
65+
error_code=self.ERRORS[0].error_code,
66+
message=self.ERRORS[0].message(),
67+
)

0 commit comments

Comments
 (0)