    tensor_already_casted_to_fp8,
    to_fp8_no_autograd,
)
-from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale
+from float8_experimental.float8_utils import (
+    e4m3_dtype,
+    get_supported_granularity,
+    tensor_to_scale,
+)
+
+SUPPORTED_GRANULARITY = get_supported_granularity()


class ActivationCasting(Enum):
@@ -75,7 +81,7 @@ def __init__(
        # FP8 specific arguments
        quant_config: QuantConfig,
        forward_config: ScaledMMConfig,
-        scaling_granularity: ScalingGranularity,
+        scaling_granularity: Optional[ScalingGranularity],
        # nn.Linear arguments
        in_features: int,
        out_features: int,
@@ -86,7 +92,26 @@ def __init__(
        # Construct the superclass this will create dummy weights and biases
        super().__init__(in_features, out_features, bias, device, dtype)
        self.forward_config = forward_config
-        self.scaling_granularity = scaling_granularity
+        if scaling_granularity is None:
+            self.scaling_granularity = (
+                ScalingGranularity.AxisWise
+                if dtype == torch.bfloat16
+                and quant_config.static_quantization_scale is None
+                else ScalingGranularity.TensorWise
+            )
+        else:
+            assert (
+                scaling_granularity in SUPPORTED_GRANULARITY
+            ), f"scaling_granularity must be in {SUPPORTED_GRANULARITY} but got {scaling_granularity}"
+            if (
+                scaling_granularity == ScalingGranularity.AxisWise
+                and dtype != torch.bfloat16
+            ):
+                raise ValueError(
+                    "AxisWise scaling granularity is only supported for bfloat16."
+                )
+            self.scaling_granularity = scaling_granularity
+
        self.activation_casting = quant_config.activation_casting
        if self.activation_casting == ActivationCasting.STATIC:
            self.register_buffer(
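Aside (not part of the diff): the branch added above resolves a default when the caller passes scaling_granularity=None. A minimal standalone sketch of that defaulting rule, where resolve_default_granularity is a hypothetical helper and the ScalingGranularity import path is an assumption:

import torch
from float8_experimental.float8_tensor import ScalingGranularity  # assumed import path

def resolve_default_granularity(dtype, static_quantization_scale):
    # Mirrors the new __init__ default: AxisWise only for bf16 dynamic quantization,
    # everything else falls back to TensorWise.
    if dtype == torch.bfloat16 and static_quantization_scale is None:
        return ScalingGranularity.AxisWise
    return ScalingGranularity.TensorWise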
@@ -101,13 +126,22 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
                input, self.weight.to_original_precision()
            )

+        # TODO we aren't folding leading dims yet, but need it to calculate the proper scale.. this sucks
+        original_m = input.shape[:-1]
+        input = input.view(-1, input.shape[-1])
+
        x_fp8 = cast_to_float8_e4m3_inference(
            input,
            self.forward_config,
            static_quantization_scale=self.static_quantization_scale,
            scaling_granularity=self.scaling_granularity,
        )
-        return torch.nn.functional.linear(x_fp8, self.weight, self.bias)
+        return torch.nn.functional.linear(x_fp8, self.weight, self.bias).view(
+            *original_m, -1
+        )
+
+    def extra_repr(self):
+        return f"{super().extra_repr()},activation_casting={self.activation_casting.name},scaling_granularity={self.scaling_granularity.name}"

    # Builder functions for Float8LinearInference
    def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
@@ -124,7 +158,12 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
        assert not isinstance(
            self.weight, Float8Tensor
        ), "Weight has already been quantized, cannot quantize again."
-        scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity)
+
+        # For weight tensors + AxisWise we calculate scales along columns
+        dim = None
+        if self.scaling_granularity == ScalingGranularity.AxisWise:
+            dim = 1
+        scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity, dim=dim)
        quantized_weight = to_fp8_no_autograd(
            self.weight, scale, dtype, self.forward_config
        )
@@ -143,19 +182,20 @@ def from_float(
        module: nn.Module,
        quant_config: QuantConfig,
        use_fast_accum: bool,
+        scaling_granularity: Optional[ScalingGranularity],
    ) -> "Float8InferenceLinear":
        """
        Create an nn.Linear with fp8 compute from another nn.Linear

        Args:
            module (nn.Module): nn.Linear to convert
            quant_config (QuantConfig): Configuration for the weight and activation casting
+            use_fast_accum (bool): Whether to enable fast accumulation for the Float8InferenceLinear.
+            scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.
        """
        forward_config = ScaledMMConfig(
            False, use_fast_accum, pad_inner_dim=config.pad_inner_dim
        )
-        # TODO: For now hardcode TensorWise scaling
-        scaling_granularity = ScalingGranularity.TensorWise
        linear = cls(
            quant_config,
            forward_config,
@@ -164,6 +204,7 @@ def from_float(
            module.out_features,
            False,
            device=torch.device("meta"),
+            dtype=module.weight.dtype,
        )
        linear.set_weight_and_bias(module.weight, module.bias)
        linear.quantize_weight()
@@ -194,18 +235,29 @@ def cast_to_float8_e4m3_inference(
    """
    if tensor_already_casted_to_fp8(inpt_tensor):
        return inpt_tensor
+
+    # For input tensors + AxisWise we calculate scales along rows
+    dim = None
+    if scaling_granularity == ScalingGranularity.AxisWise:
+        dim = 1
+
    scale = (
        static_quantization_scale
        if static_quantization_scale is not None
        else tensor_to_scale(
-            inpt_tensor, e4m3_dtype, scaling_granularity, reduce_amax=reduce_amax
+            inpt_tensor,
+            e4m3_dtype,
+            scaling_granularity,
+            dim=dim,
+            reduce_amax=reduce_amax,
        )
    )
    return Float8Tensor.to_float8(
        inpt_tensor,
        scale,
        e4m3_dtype,
        mm_config=mm_config,
+        scaling_granularity=scaling_granularity,
    )

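For intuition, a hedged, self-contained illustration of the per-row (AxisWise) e4m3 scale this dim=1 choice implies for an activation tensor, using only plain torch; this is not the repo's tensor_to_scale, just the amax-based idea behind it, and rowwise_e4m3_scale is a hypothetical helper:

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def rowwise_e4m3_scale(x: torch.Tensor) -> torch.Tensor:
    # One scale per row: reduce the absolute max over the last dim and keep it
    # so the scale broadcasts when multiplying back against x.
    amax = x.abs().amax(dim=1, keepdim=True).float()
    return E4M3_MAX / torch.clamp(amax, min=1e-12)

x = torch.randn(4, 8, dtype=torch.bfloat16)
x_fp8 = (x.float() * rowwise_e4m3_scale(x)).to(torch.float8_e4m3fn)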
@@ -215,6 +267,7 @@ def quantize_to_float8(
    *,
    skip_fqn_list: Optional[List[str]] = None,
    use_fast_accum: bool = True,
+    scaling_granularity: Optional[ScalingGranularity] = None,
) -> Optional[nn.Module]:
    """
    Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.
@@ -228,6 +281,7 @@ def quantize_to_float8(
        quant_config (QuantConfig): Quantization configuration for Float8 conversion.
        skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
        use_fast_accum: Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
+        scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.

    Returns:
        nn.Module: The modified module with applicable Linear layers converted to Float8.
@@ -237,6 +291,8 @@ def quantize_to_float8(
    """
    return swap_linear_layers(
        module,
-        lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
+        lambda m: Float8InferenceLinear.from_float(
+            m, quant_config, use_fast_accum, scaling_granularity
+        ),
        skip_fqn_list=skip_fqn_list,
    )
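A hedged usage sketch of the new keyword as it is threaded through quantize_to_float8 (not part of the diff; the import paths, the QuantConfig constructor, and ActivationCasting.DYNAMIC are assumptions based on this file):

import torch
import torch.nn as nn
from float8_experimental.float8_tensor import ScalingGranularity  # assumed import path
from float8_experimental.inference import (  # assumed import path
    ActivationCasting,
    QuantConfig,
    quantize_to_float8,
)

# AxisWise requires bf16 weights per the new check in __init__.
model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024)).to(torch.bfloat16).cuda()

quant_config = QuantConfig(ActivationCasting.DYNAMIC)
quantize_to_float8(model, quant_config, scaling_granularity=ScalingGranularity.AxisWise)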