 # LICENSE file in the root directory of this source tree.
 """
 A simple module swap UX for a float8 version of `torch.nn.Linear` which
-does not require `torch.compile` to be performant..
+does not require `torch.compile` to be performant.
 """
+from typing import Optional
 
 import torch
 
+from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
+from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
+from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_float8
+from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
+from torchao.float8.float8_tensor import (
+    GemmInputRole,
+    hp_tensor_and_scale_to_float8,
+    LinearMMConfig,
+    ScaledMMConfig,
+)
+from torchao.float8.float8_utils import tensor_to_scale
+
+from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
+    hp_tensor_to_float8nocompile_dynamic,
+)
+
 
 class Float8LinearNoCompile(torch.nn.Linear):
     """
@@ -19,4 +36,111 @@ class Float8LinearNoCompile(torch.nn.Linear):
     Note: this is **prototype** and not suitable for production use.
     """
 
-    pass
+    def __init__(self, *args, **kwargs):
+        """
+        Additional arguments on top of `torch.nn.Linear`'s arguments:
+        * `config`: Float8LinearConfig
+        """
+        config = kwargs.pop("config")
+        emulate = config.emulate
+        super().__init__(*args, **kwargs)
+
+        self.config = config
+
+        self.linear_mm_config = LinearMMConfig(
+            # output
+            ScaledMMConfig(
+                emulate,
+                self.config.gemm_config_output.use_fast_accum,
+                False,
+                self.config.pad_inner_dim,
+            ),
+            # grad_input
+            ScaledMMConfig(
+                emulate,
+                self.config.gemm_config_grad_input.use_fast_accum,
+                False,
+                self.config.pad_inner_dim,
+            ),
+            # grad_weight
+            ScaledMMConfig(
+                emulate,
+                self.config.gemm_config_grad_weight.use_fast_accum,
+                False,
+                self.config.pad_inner_dim,
+            ),
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # TODO(danielvegamyhre): replace conversions with triton kernels
+        # TODO(danielvegamyhre): support for FSDP once dependencies are implemented
+        input_fp8 = self.cast_input_to_float8(input)
+        weight_fp8_t = self.cast_weight_to_float8_t(self.weight)
+
+        # compute fp8 matmul
+        output = manual_float8_matmul_with_args_in_float8.apply(input_fp8, weight_fp8_t)
+
+        # cast grad_output to float8_e5m2 during backward
+        return self.cast_output_to_float8_in_bw(output)
+
+    def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
+        # Duplicate the autocast logic for F.linear, so that the output
+        # of our module has the right original precision
+        if torch.is_autocast_enabled():
+            # For now, hardcode to GPU's autocast dtype
+            # if we need CPU support in the future, we can add it
+            autocast_dtype = torch.get_autocast_gpu_dtype()
+            input = input.to(autocast_dtype)
+
+        # TODO(danielvegamyhre): implement this fn in scaling_utils with call to triton kernel
+        return hp_tensor_to_float8nocompile_dynamic(
+            input,
+            self.config.cast_config_input.target_dtype,
+            self.linear_mm_config,
+            gemm_input_role=GemmInputRole.INPUT,
+        )
+
+    def cast_weight_to_float8_t(
+        self,
+        weight: torch.Tensor,
+    ) -> torch.Tensor:
+        # TODO(danielvegamyhre): replace conversion with triton kernel
+        weight_fp8 = hp_tensor_to_float8nocompile_dynamic(
+            weight,
+            self.config.cast_config_weight.target_dtype,
+            self.linear_mm_config,
+            gemm_input_role=GemmInputRole.WEIGHT,
+        )
+        return weight_fp8.t()
+
+    def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
+        # casts grad_output to float8_e5m2 for backward
+        # TODO(danielvegamyhre): replace conversion with triton kernel
+        return NoopFwToFloat8BwDynamic.apply(
+            output,
+            self.linear_mm_config,
+            self.config.cast_config_grad_output.target_dtype,
+        )
+
+    @classmethod
+    def from_float(cls, mod):
+        """
+        Create an nn.Linear with fp8 compute from a regular nn.Linear
+
+        Args:
+            mod (torch.nn.Linear): nn.Linear to convert
+            config (Optional[Float8LinearConfig]): configuration for conversion to float8
+        """
+        config = Float8LinearConfig()
+        with torch.device("meta"):
+            new_mod = cls(
+                mod.in_features,
+                mod.out_features,
+                bias=False,
+                config=config,
+            )
+        new_mod.weight = mod.weight
+        new_mod.bias = mod.bias
+
+        # TODO(danielvegamyhre): support for FSDP once dependencies are implemented
+        return new_mod
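For context, here is a minimal sketch of how the module swap could be driven from user code. The toy `TwoLayerMLP` model, the swap loop, and the tensor shapes are illustrative only (not part of this change), and the import path assumes the class lives in `torchao.prototype.float8nocompile.float8nocompile_linear`:

# Hypothetical driver script; not part of this PR.
import torch
from torchao.prototype.float8nocompile.float8nocompile_linear import (
    Float8LinearNoCompile,  # assumption: module path matches this file's location
)

class TwoLayerMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # bias=False because the float8 forward shown above only does the matmul
        self.fc1 = torch.nn.Linear(1024, 4096, bias=False)
        self.fc2 = torch.nn.Linear(4096, 1024, bias=False)

    def forward(self, x):
        return self.fc2(torch.nn.functional.relu(self.fc1(x)))

model = TwoLayerMLP().to("cuda", dtype=torch.bfloat16)

# swap each nn.Linear for the float8 version, reusing the existing weights
for name, child in model.named_children():
    if isinstance(child, torch.nn.Linear):
        setattr(model, name, Float8LinearNoCompile.from_float(child))

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = model(x)
y.sum().backward()  # grad_output is cast to float8_e5m2 on the way back

Because `from_float` constructs the new module on the meta device and then re-points `weight`/`bias` at the original parameters, the swap reuses the existing high-precision weights; they are only cast to float8 on the fly inside `forward`.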
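The casts in `forward` rely on per-tensor dynamic scaling via `hp_tensor_to_float8nocompile_dynamic`, whose implementation is not shown in this diff. As a rough sketch of the generic dynamic-cast recipe only (compute amax, derive a scale that maps amax onto the float8 dtype's representable range, saturate, cast); the real helper presumably also wraps the result in a float8 tensor subclass carrying `linear_mm_config` and `gemm_input_role`, and is slated to move to Triton kernels per the TODOs above:

import torch

def dynamic_cast_sketch(
    hp_tensor: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,  # assumption: e4m3 for inputs/weights
):
    """Illustrative per-tensor dynamic cast; not the torchao implementation."""
    fp8_max = torch.finfo(float8_dtype).max
    # scale chosen so the tensor's absolute max maps onto the float8 range
    amax = hp_tensor.abs().max().to(torch.float64)
    scale = (fp8_max / torch.clamp(amax, min=1e-12)).to(torch.float32)
    # scale, saturate to the representable range, then cast down to float8
    fp8_data = (hp_tensor.float() * scale).clamp(-fp8_max, fp8_max).to(float8_dtype)
    return fp8_data, scale  # the scale is kept so the scaled matmul can de-scale its result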