@@ -41,20 +41,11 @@ def __init__(self, *args, **kwargs):
         Additional arguments on top of `torch.nn.Linear`'s arguments:
         * `config`: Float8LinearConfig
         """
-
-        # Amax scales should always be kept as float32.
-        self.always_float32_buffers = set()
         config = kwargs.pop("config")
         emulate = config.emulate
         super().__init__(*args, **kwargs)
 
-        # Defines the scaling behavior of input, weight, grad_output
-        self.scaling_type_input = config.cast_config_input.scaling_type
-        self.scaling_type_weight = config.cast_config_weight.scaling_type
-        self.scaling_type_grad_output = config.cast_config_grad_output.scaling_type
-
         self.config = config
-        self.is_amax_initialized = not self.config.enable_amax_init
 
         self.linear_mm_config = LinearMMConfig(
             # output
@@ -81,31 +72,18 @@ def __init__(self, *args, **kwargs):
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # TODO(danielvegamyhre): modify to support for FSDP once dependencies are implemented
-        output = self.forward_fp8_matmul(input)
-        if self.bias is not None:
-            output = output + self.bias.to(output.dtype)
-        return output
-
-    def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
-        # perform hp to fp8 conversions
-        # TODO(danielvegamyhre): replace conversion with triton kernels
-        input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
-        weight_scale = self.get_weight_scale(self.weight)
-        weight_fp8_t = self.cast_weight_to_float8_t(
-            self.weight, self.is_amax_initialized, weight_scale
-        )
+        # TODO(danielvegamyhre): replace conversions with triton kernels
+        # TODO(danielvegamyhre): support for FSDP once dependencies are implemented
+        input_fp8 = self.cast_input_to_float8(input)
+        weight_fp8_t = self.cast_weight_to_float8_t(self.weight)
 
         # compute fp8 matmul
         output = manual_float8_matmul_with_args_in_float8.apply(input_fp8, weight_fp8_t)
 
         # cast grad_output to float8_e5m2 during backward
-        # TODO(danielvegamyhre): replace with triton kernel
         return self.cast_output_to_float8_in_bw(output)
 
-    def cast_input_to_float8(
-        self, input: torch.Tensor, is_amax_initialized: bool
-    ) -> torch.Tensor:
+    def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
         # Duplicate the autocast logic for F.linear, so that the output
         # of our module has the right original precision
         if torch.is_autocast_enabled():
@@ -122,32 +100,21 @@ def cast_input_to_float8(
             gemm_input_role=GemmInputRole.INPUT,
         )
 
-    def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
-        # TODO(danielvegamyhre): replace scale calculation with triton kernel
-        if tensor_already_casted_to_fp8(weight):
-            return None
-        return tensor_to_scale(weight, self.config.cast_config_weight.target_dtype)
-
     def cast_weight_to_float8_t(
         self,
         weight: torch.Tensor,
-        is_amax_initialized: bool,
-        weight_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if tensor_already_casted_to_fp8(weight):
-            return weight.t()
-
         # TODO(danielvegamyhre): replace conversion with triton kernel
-        weight_fp8 = hp_tensor_and_scale_to_float8(
+        weight_fp8 = hp_tensor_to_float8nocompile_dynamic(
             weight,
-            weight_scale,
             self.config.cast_config_weight.target_dtype,
             self.linear_mm_config,
             gemm_input_role=GemmInputRole.WEIGHT,
         )
         return weight_fp8.t()
 
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
+        # casts grad_output to float8_e5m2 for backward
         # TODO(danielvegamyhre): replace conversion with triton kernel
         return NoopFwToFloat8BwDynamic.apply(
             output,
@@ -156,20 +123,15 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         )
 
     @classmethod
-    def from_float(
-        cls,
-        mod,
-        config: Optional[Float8LinearConfig] = None,
-    ):
+    def from_float(cls, mod):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             config (Optional[Float8LinearConfig]): configuration for conversion to float8
         """
-        if config is None:
-            config = Float8LinearConfig()
+        config = Float8LinearConfig()
         with torch.device("meta"):
             new_mod = cls(
                 mod.in_features,
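For context, here is a minimal usage sketch of the simplified API after this change. It assumes the class being modified is the prototype `Float8LinearNoCompile` module from torchao's float8nocompile prototype (the class name and import path are not shown in this diff and are assumptions), and that the GPU supports float8 matmul or that `config.emulate` is set. With this diff, `from_float` no longer accepts a config argument and always builds a default `Float8LinearConfig`; the forward pass casts input and weight to float8 dynamically and casts grad_output to float8_e5m2 during backward.

import torch
import torch.nn as nn

# Assumed import path; not shown in this diff.
from torchao.prototype.float8nocompile.float8nocompile_linear import (
    Float8LinearNoCompile,
)

# Start from a regular high-precision linear layer (bias omitted, since the
# simplified forward path in this diff no longer adds the bias itself).
orig = nn.Linear(64, 32, bias=False, device="cuda", dtype=torch.bfloat16)

# from_float now takes only the module; a default Float8LinearConfig is built internally.
fp8_linear = Float8LinearNoCompile.from_float(orig)

# Forward dynamically casts input and weight to float8; backward casts
# grad_output to float8_e5m2 via NoopFwToFloat8BwDynamic.
x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = fp8_linear(x)
out.sum().backward()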