Skip to content

Commit 4ff3caf

Browse files
authored
Add rules for deprecated AMP APIs (#87)
Add codemods for `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast`, and checkers for `torch.cuda.amp.custom_fwd` and `torch.cuda.amp.custom_bwd`.
1 parent 86186f4 commit 4ff3caf

File tree

7 files changed

+91
-6
lines changed

7 files changed

+91
-6
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import torch
2+
3+
torch.cuda.amp.autocast()
4+
torch.cuda.amp.custom_fwd()
5+
torch.cuda.amp.custom_bwd()
6+
7+
dtype = torch.float32
8+
maybe_autocast = torch.cpu.amp.autocast()
9+
maybe_autocast = torch.cpu.amp.autocast(dtype=torch.bfloat16)
10+
maybe_autocast = torch.cpu.amp.autocast(dtype=dtype)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
3:1 TOR101 Use of deprecated function torch.cuda.amp.autocast
2+
4:1 TOR101 Use of deprecated function torch.cuda.amp.custom_fwd
3+
5:1 TOR101 Use of deprecated function torch.cuda.amp.custom_bwd
4+
8:18 TOR101 Use of deprecated function torch.cpu.amp.autocast
5+
9:18 TOR101 Use of deprecated function torch.cpu.amp.autocast
6+
10:18 TOR101 Use of deprecated function torch.cpu.amp.autocast
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import torch
2+
3+
dtype = torch.float32
4+
5+
maybe_autocast = torch.cuda.amp.autocast()
6+
maybe_autocast = torch.cuda.amp.autocast(dtype=torch.bfloat16)
7+
maybe_autocast = torch.cuda.amp.autocast(dtype=dtype)
8+
9+
maybe_autocast = torch.cpu.amp.autocast()
10+
maybe_autocast = torch.cpu.amp.autocast(dtype=torch.bfloat16)
11+
maybe_autocast = torch.cpu.amp.autocast(dtype=dtype)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import torch
2+
3+
dtype = torch.float32
4+
5+
maybe_autocast = torch.amp.autocast("cuda")
6+
maybe_autocast = torch.amp.autocast("cuda", dtype=torch.bfloat16)
7+
maybe_autocast = torch.amp.autocast("cuda", dtype=dtype)
8+
9+
maybe_autocast = torch.amp.autocast("cpu")
10+
maybe_autocast = torch.amp.autocast("cpu", dtype=torch.bfloat16)
11+
maybe_autocast = torch.amp.autocast("cpu", dtype=dtype)

torchfix/deprecated_symbols.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,22 @@
8383
remove_pr:
8484
reference: https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel
8585

86+
- name: torch.cuda.amp.autocast
87+
deprecate_pr: TBA
88+
remove_pr:
89+
90+
- name: torch.cuda.amp.custom_fwd
91+
deprecate_pr: TBA
92+
remove_pr:
93+
94+
- name: torch.cuda.amp.custom_bwd
95+
deprecate_pr: TBA
96+
remove_pr:
97+
98+
- name: torch.cpu.amp.autocast
99+
deprecate_pr: TBA
100+
remove_pr:
101+
86102
# functorch
87103
- name: functorch.vmap
88104
deprecate_pr: TBA

torchfix/visitors/deprecated_symbols/__init__.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
1-
import libcst as cst
21
import pkgutil
2+
from typing import List, Optional
3+
4+
import libcst as cst
35
import yaml
4-
from typing import Optional, List
56

67
from ...common import (
7-
TorchVisitor,
8-
TorchError,
98
call_with_name_changes,
109
check_old_names_in_import_from,
10+
TorchError,
11+
TorchVisitor,
1112
)
1213

13-
from .range import call_replacement_range
14-
from .cholesky import call_replacement_cholesky
14+
from .amp import call_replacement_cpu_amp_autocast, call_replacement_cuda_amp_autocast
1515
from .chain_matmul import call_replacement_chain_matmul
16+
from .cholesky import call_replacement_cholesky
1617
from .qr import call_replacement_qr
1718

19+
from .range import call_replacement_range
20+
1821

1922
class TorchDeprecatedSymbolsVisitor(TorchVisitor):
2023
ERRORS: List[TorchError] = [
@@ -49,6 +52,8 @@ def _call_replacement(
4952
"torch.range": call_replacement_range,
5053
"torch.chain_matmul": call_replacement_chain_matmul,
5154
"torch.qr": call_replacement_qr,
55+
"torch.cuda.amp.autocast": call_replacement_cuda_amp_autocast,
56+
"torch.cpu.amp.autocast": call_replacement_cpu_amp_autocast,
5257
}
5358
replacement = None
5459

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import libcst as cst
2+
3+
from ...common import get_module_name
4+
5+
6+
def call_replacement_cpu_amp_autocast(node: cst.Call) -> cst.CSTNode:
7+
return _call_replacement_amp(node, "cpu")
8+
9+
10+
def call_replacement_cuda_amp_autocast(node: cst.Call) -> cst.CSTNode:
11+
return _call_replacement_amp(node, "cuda")
12+
13+
14+
def _call_replacement_amp(node: cst.Call, device: str) -> cst.CSTNode:
15+
"""
16+
Replace `torch.cuda.amp.autocast()` with `torch.amp.autocast("cuda")` and
17+
Replace `torch.cpu.amp.autocast()` with `torch.amp.autocast("cpu")`.
18+
"""
19+
device_arg = cst.ensure_type(cst.parse_expression(f'f("{device}")'), cst.Call).args[
20+
0
21+
]
22+
23+
module_name = get_module_name(node, "torch")
24+
replacement = cst.parse_expression(f"{module_name}.amp.autocast(args)")
25+
replacement = replacement.with_changes(args=(device_arg, *node.args))
26+
return replacement

0 commit comments

Comments
 (0)