File tree Expand file tree Collapse file tree 7 files changed +91
-6
lines changed
tests/fixtures/deprecated_symbols
visitors/deprecated_symbols Expand file tree Collapse file tree 7 files changed +91
-6
lines changed Original file line number Diff line number Diff line change
1
+ import torch
2
+
3
+ torch .cuda .amp .autocast ()
4
+ torch .cuda .amp .custom_fwd ()
5
+ torch .cuda .amp .custom_bwd ()
6
+
7
+ dtype = torch .float32
8
+ maybe_autocast = torch .cpu .amp .autocast ()
9
+ maybe_autocast = torch .cpu .amp .autocast (dtype = torch .bfloat16 )
10
+ maybe_autocast = torch .cpu .amp .autocast (dtype = dtype )
Original file line number Diff line number Diff line change
1
+ 3:1 TOR101 Use of deprecated function torch.cuda.amp.autocast
2
+ 4:1 TOR101 Use of deprecated function torch.cuda.amp.custom_fwd
3
+ 5:1 TOR101 Use of deprecated function torch.cuda.amp.custom_bwd
4
+ 8:18 TOR101 Use of deprecated function torch.cpu.amp.autocast
5
+ 9:18 TOR101 Use of deprecated function torch.cpu.amp.autocast
6
+ 10:18 TOR101 Use of deprecated function torch.cpu.amp.autocast
Original file line number Diff line number Diff line change
1
+ import torch
2
+
3
+ dtype = torch .float32
4
+
5
+ maybe_autocast = torch .cuda .amp .autocast ()
6
+ maybe_autocast = torch .cuda .amp .autocast (dtype = torch .bfloat16 )
7
+ maybe_autocast = torch .cuda .amp .autocast (dtype = dtype )
8
+
9
+ maybe_autocast = torch .cpu .amp .autocast ()
10
+ maybe_autocast = torch .cpu .amp .autocast (dtype = torch .bfloat16 )
11
+ maybe_autocast = torch .cpu .amp .autocast (dtype = dtype )
Original file line number Diff line number Diff line change
1
+ import torch
2
+
3
+ dtype = torch .float32
4
+
5
+ maybe_autocast = torch .amp .autocast ("cuda" )
6
+ maybe_autocast = torch .amp .autocast ("cuda" , dtype = torch .bfloat16 )
7
+ maybe_autocast = torch .amp .autocast ("cuda" , dtype = dtype )
8
+
9
+ maybe_autocast = torch .amp .autocast ("cpu" )
10
+ maybe_autocast = torch .amp .autocast ("cpu" , dtype = torch .bfloat16 )
11
+ maybe_autocast = torch .amp .autocast ("cpu" , dtype = dtype )
Original file line number Diff line number Diff line change 83
83
remove_pr :
84
84
reference : https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel
85
85
86
+ - name : torch.cuda.amp.autocast
87
+ deprecate_pr : TBA
88
+ remove_pr :
89
+
90
+ - name : torch.cuda.amp.custom_fwd
91
+ deprecate_pr : TBA
92
+ remove_pr :
93
+
94
+ - name : torch.cuda.amp.custom_bwd
95
+ deprecate_pr : TBA
96
+ remove_pr :
97
+
98
+ - name : torch.cpu.amp.autocast
99
+ deprecate_pr : TBA
100
+ remove_pr :
101
+
86
102
# functorch
87
103
- name : functorch.vmap
88
104
deprecate_pr : TBA
Original file line number Diff line number Diff line change 1
- import libcst as cst
2
1
import pkgutil
2
+ from typing import List , Optional
3
+
4
+ import libcst as cst
3
5
import yaml
4
- from typing import Optional , List
5
6
6
7
from ...common import (
7
- TorchVisitor ,
8
- TorchError ,
9
8
call_with_name_changes ,
10
9
check_old_names_in_import_from ,
10
+ TorchError ,
11
+ TorchVisitor ,
11
12
)
12
13
13
- from .range import call_replacement_range
14
- from .cholesky import call_replacement_cholesky
14
+ from .amp import call_replacement_cpu_amp_autocast , call_replacement_cuda_amp_autocast
15
15
from .chain_matmul import call_replacement_chain_matmul
16
+ from .cholesky import call_replacement_cholesky
16
17
from .qr import call_replacement_qr
17
18
19
+ from .range import call_replacement_range
20
+
18
21
19
22
class TorchDeprecatedSymbolsVisitor (TorchVisitor ):
20
23
ERRORS : List [TorchError ] = [
@@ -49,6 +52,8 @@ def _call_replacement(
49
52
"torch.range" : call_replacement_range ,
50
53
"torch.chain_matmul" : call_replacement_chain_matmul ,
51
54
"torch.qr" : call_replacement_qr ,
55
+ "torch.cuda.amp.autocast" : call_replacement_cuda_amp_autocast ,
56
+ "torch.cpu.amp.autocast" : call_replacement_cpu_amp_autocast ,
52
57
}
53
58
replacement = None
54
59
Original file line number Diff line number Diff line change
1
+ import libcst as cst
2
+
3
+ from ...common import get_module_name
4
+
5
+
6
+ def call_replacement_cpu_amp_autocast (node : cst .Call ) -> cst .CSTNode :
7
+ return _call_replacement_amp (node , "cpu" )
8
+
9
+
10
+ def call_replacement_cuda_amp_autocast (node : cst .Call ) -> cst .CSTNode :
11
+ return _call_replacement_amp (node , "cuda" )
12
+
13
+
14
+ def _call_replacement_amp (node : cst .Call , device : str ) -> cst .CSTNode :
15
+ """
16
+ Replace `torch.cuda.amp.autocast()` with `torch.amp.autocast("cuda")` and
17
+ Replace `torch.cpu.amp.autocast()` with `torch.amp.autocast("cpu")`.
18
+ """
19
+ device_arg = cst .ensure_type (cst .parse_expression (f'f("{ device } ")' ), cst .Call ).args [
20
+ 0
21
+ ]
22
+
23
+ module_name = get_module_name (node , "torch" )
24
+ replacement = cst .parse_expression (f"{ module_name } .amp.autocast(args)" )
25
+ replacement = replacement .with_changes (args = (device_arg , * node .args ))
26
+ return replacement
You can’t perform that action at this time.
0 commit comments