Commit 03ea18c

Add a rule for use_reentrant with checkpoint (#7)
1 parent 4748cfb commit 03ea18c

File tree: 6 files changed, +71 −1 lines.
Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+import torch
+def gn(x, y):
+    return torch.sigmoid(torch.matmul(x, y))
+
+import torch.utils.checkpoint
+def fn(x, y):
+    return checkpoint(gn, torch.sin(x), y)
+    return checkpoint(gn, torch.sin(x), y, use_reentrant=False)
+
+from torch.utils.checkpoint import checkpoint
+def fn(x, y):
+    return checkpoint(gn, torch.sin(x), y)
+    return checkpoint(gn, torch.sin(x), y, use_reentrant=True)
Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+7:12 TOR003 Please pass `use_reentrant` explicitly to `checkpoint`. To maintain old behavior, pass `use_reentrant=True`. It is recommended to use `use_reentrant=False`.
+12:12 TOR003 Please pass `use_reentrant` explicitly to `checkpoint`. To maintain old behavior, pass `use_reentrant=True`. It is recommended to use `use_reentrant=False`.
Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+import torch
+from torch.utils.checkpoint import checkpoint
+def gn(x, y):
+    return torch.sigmoid(torch.matmul(x, y))
+def fn(x, y):
+    return checkpoint(gn, torch.sin(x), y)
Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+import torch
+from torch.utils.checkpoint import checkpoint
+def gn(x, y):
+    return torch.sigmoid(torch.matmul(x, y))
+def fn(x, y):
+    return checkpoint(gn, torch.sin(x), y, use_reentrant=False)
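
For context on what these before/after fixtures exercise: when `use_reentrant` is omitted, `torch.utils.checkpoint.checkpoint` has historically fallen back to the reentrant implementation, and PyTorch recommends passing `use_reentrant=False` explicitly. Below is a minimal, self-contained sketch of the recommended call site; the function and tensor names are illustrative and not part of this commit.

import torch
from torch.utils.checkpoint import checkpoint

def block(x, w):
    # The activation is recomputed during backward instead of being stored.
    return torch.sigmoid(torch.matmul(x, w))

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(16, 16, requires_grad=True)

# Explicit use_reentrant=False selects the non-reentrant implementation,
# which is what the TOR003 message recommends.
out = checkpoint(block, x, w, use_reentrant=False)
out.sum().backward()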

torchfix/torchfix.py

Lines changed: 3 additions & 1 deletion

@@ -11,7 +11,8 @@
 )
 
 from .visitors.performance import TorchSynchronizedDataLoaderVisitor
-from .visitors.misc import TorchRequireGradVisitor
+from .visitors.misc import (TorchRequireGradVisitor, TorchReentrantCheckpointVisitor)
+
 from .visitors.vision import (
     TorchVisionDeprecatedPretrainedVisitor,
     TorchVisionDeprecatedToTensorVisitor,
@@ -33,6 +34,7 @@ def GET_ALL_VISITORS():
         TorchVisionDeprecatedPretrainedVisitor(),
         TorchVisionDeprecatedToTensorVisitor(),
         TorchUnsafeLoadVisitor(),
+        TorchReentrantCheckpointVisitor(),
     ]
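
Registering `TorchReentrantCheckpointVisitor()` in `GET_ALL_VISITORS()` is all the wiring the new rule needs; the traversal itself is driven by libcst. The toy visitor below is not TorchFix's own base class, only a rough sketch of the position-provider pattern the rule relies on.

import libcst as cst
from libcst.metadata import MetadataWrapper, WhitespaceInclusivePositionProvider

class CheckpointCallFinder(cst.CSTVisitor):
    # Declare the metadata we want libcst to resolve before the visit.
    METADATA_DEPENDENCIES = (WhitespaceInclusivePositionProvider,)

    def __init__(self):
        super().__init__()
        self.calls = []

    def visit_Call(self, node):
        # Record the position of every call spelled `checkpoint(...)`.
        if isinstance(node.func, cst.Name) and node.func.value == "checkpoint":
            pos = self.get_metadata(WhitespaceInclusivePositionProvider, node)
            self.calls.append((pos.start.line, pos.start.column))

source = "from torch.utils.checkpoint import checkpoint\ny = checkpoint(f, x)\n"
finder = CheckpointCallFinder()
MetadataWrapper(cst.parse_module(source)).visit(finder)
print(finder.calls)  # e.g. [(2, 4)]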
torchfix/visitors/misc/__init__.py

Lines changed: 41 additions & 0 deletions

@@ -46,3 +46,44 @@ def visit_Assign(self, node):
                     replacement=replacement,
                 )
             )
+
+
+class TorchReentrantCheckpointVisitor(TorchVisitor):
+    """
+    Find and fix common misuse of reentrant checkpoints.
+    """
+
+    ERROR_CODE = "TOR003"
+    MESSAGE = (
+        "Please pass `use_reentrant` explicitly to `checkpoint`. "
+        "To maintain old behavior, pass `use_reentrant=True`. "
+        "It is recommended to use `use_reentrant=False`."
+    )
+
+    def visit_Call(self, node):
+        qualified_name = self.get_qualified_name_for_call(node)
+        if qualified_name == "torch.utils.checkpoint.checkpoint":
+            use_reentrant_arg = self.get_specific_arg(node, "use_reentrant", -1)
+            if use_reentrant_arg is None:
+                position_metadata = self.get_metadata(
+                    cst.metadata.WhitespaceInclusivePositionProvider, node
+                )
+
+                # This codemod maybe unsafe correctness-wise
+                # if reentrant behavior is actually needed,
+                # so the changes need to be verified/tested.
+                use_reentrant_arg = cst.ensure_type(
+                    cst.parse_expression("f(use_reentrant=False)"), cst.Call
+                ).args[0]
+                replacement = node.with_changes(args=node.args + (use_reentrant_arg,))
+
+                self.violations.append(
+                    LintViolation(
+                        error_code=self.ERROR_CODE,
+                        message=self.MESSAGE,
+                        line=position_metadata.start.line,
+                        column=position_metadata.start.column,
+                        node=node,
+                        replacement=replacement,
+                    )
+                )
