     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale
+from float8_experimental.float8_utils import (
+    e4m3_dtype,
+    get_supported_granularity,
+    tensor_to_scale,
+)
+
+SUPPORTED_GRANULARITY = get_supported_granularity()


 class ActivationCasting(Enum):
@@ -75,7 +81,7 @@ def __init__(
         # FP8 specific arguments
         quant_config: QuantConfig,
         forward_config: ScaledMMConfig,
-        scaling_granularity: ScalingGranularity,
+        scaling_granularity: Optional[ScalingGranularity],
         # nn.Linear arguments
         in_features: int,
         out_features: int,
@@ -86,7 +92,26 @@ def __init__(
         # Construct the superclass this will create dummy weights and biases
         super().__init__(in_features, out_features, bias, device, dtype)
         self.forward_config = forward_config
-        self.scaling_granularity = scaling_granularity
+        if scaling_granularity is None:
+            self.scaling_granularity = (
+                ScalingGranularity.AxisWise
+                if dtype == torch.bfloat16
+                and quant_config.static_quantization_scale is None
+                else ScalingGranularity.TensorWise
+            )
+        else:
+            assert (
+                scaling_granularity in SUPPORTED_GRANULARITY
+            ), f"scaling_granularity must be in {SUPPORTED_GRANULARITY} but got {scaling_granularity}"
+            if (
+                scaling_granularity == ScalingGranularity.AxisWise
+                and dtype != torch.bfloat16
+            ):
+                raise ValueError(
+                    "AxisWise scaling granularity is only supported for bfloat16."
+                )
+            self.scaling_granularity = scaling_granularity
+
         self.activation_casting = quant_config.activation_casting
         if self.activation_casting == ActivationCasting.STATIC:
             self.register_buffer(
@@ -101,13 +126,19 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
                 input, self.weight.to_original_precision()
             )

+        # TODO we aren't folding leading dims yet, but need it to calculate the proper scale.. this sucks
+        original_m = input.shape[:-1]
+        input = input.view(-1, input.shape[-1])
+
         x_fp8 = cast_to_float8_e4m3_inference(
             input,
             self.forward_config,
             static_quantization_scale=self.static_quantization_scale,
             scaling_granularity=self.scaling_granularity,
         )
-        return torch.nn.functional.linear(x_fp8, self.weight, self.bias)
+        return torch.nn.functional.linear(x_fp8, self.weight, self.bias).view(
+            *original_m, -1
+        )

     # Builder functions for Float8LinearInference
     def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
@@ -124,7 +155,12 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
         assert not isinstance(
             self.weight, Float8Tensor
         ), "Weight has already been quantized, cannot quantize again."
-        scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity)
+
+        # For weight tensors + AxisWise we calculate scales along columns
+        dim = None
+        if self.scaling_granularity == ScalingGranularity.AxisWise:
+            dim = 1
+        scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity, dim=dim)
         quantized_weight = to_fp8_no_autograd(
             self.weight, scale, dtype, self.forward_config
         )
@@ -143,19 +179,20 @@ def from_float(
         module: nn.Module,
         quant_config: QuantConfig,
         use_fast_accum: bool,
+        scaling_granularity: Optional[ScalingGranularity],
     ) -> "Float8InferenceLinear":
         """
         Create an nn.Linear with fp8 compute from another nn.Linear

         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             quant_config (QuantConfig): Configuration for the weight and activation casting
+            use_fast_accum (bool): Whether to enable fast accumulation for the Float8InferenceLinear.
+            scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.
         """
         forward_config = ScaledMMConfig(
             False, use_fast_accum, pad_inner_dim=config.pad_inner_dim
         )
-        # TODO: For now hardcode TensorWise scaling
-        scaling_granularity = ScalingGranularity.TensorWise
         linear = cls(
             quant_config,
             forward_config,
@@ -164,6 +201,7 @@ def from_float(
             module.out_features,
             False,
             device=torch.device("meta"),
+            dtype=module.weight.dtype,
         )
         linear.set_weight_and_bias(module.weight, module.bias)
         linear.quantize_weight()
@@ -194,18 +232,29 @@ def cast_to_float8_e4m3_inference(
     """
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
+
+    # For input tensors + AxisWise we calculate scales along rows
+    dim = None
+    if scaling_granularity == ScalingGranularity.AxisWise:
+        dim = 1
+
     scale = (
         static_quantization_scale
         if static_quantization_scale is not None
         else tensor_to_scale(
-            inpt_tensor, e4m3_dtype, scaling_granularity, reduce_amax=reduce_amax
+            inpt_tensor,
+            e4m3_dtype,
+            scaling_granularity,
+            dim=dim,
+            reduce_amax=reduce_amax,
         )
     )
     return Float8Tensor.to_float8(
         inpt_tensor,
         scale,
         e4m3_dtype,
         mm_config=mm_config,
+        scaling_granularity=scaling_granularity,
     )

@@ -215,6 +264,7 @@ def quantize_to_float8(
     *,
     skip_fqn_list: Optional[List[str]] = None,
     use_fast_accum: bool = True,
+    scaling_granularity: Optional[ScalingGranularity] = None,
 ) -> Optional[nn.Module]:
     """
     Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.
@@ -228,6 +278,7 @@ def quantize_to_float8(
         quant_config (QuantConfig): Quantization configuration for Float8 conversion.
         skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
         use_fast_accum: Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
+        scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.

     Returns:
         nn.Module: The modified module with applicable Linear layers converted to Float8.
@@ -237,6 +288,8 @@ def quantize_to_float8(
     """
     return swap_linear_layers(
         module,
-        lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
+        lambda m: Float8InferenceLinear.from_float(
+            m, quant_config, use_fast_accum, scaling_granularity
+        ),
         skip_fqn_list=skip_fqn_list,
     )
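
A minimal usage sketch of the new optional scaling_granularity argument, based only on the names visible in this diff. The import paths, the ActivationCasting.DYNAMIC member, and the QuantConfig(...) constructor call are assumptions and may differ from the actual package layout:

import torch
import torch.nn as nn

from float8_experimental.float8_tensor import ScalingGranularity  # assumed path
from float8_experimental.inference import (  # assumed path
    ActivationCasting,
    QuantConfig,
    quantize_to_float8,
)

# AxisWise is gated on bfloat16 in __init__, so cast the model first.
model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024)).to(torch.bfloat16)

# Dynamic activation casting; ActivationCasting.DYNAMIC is assumed here.
quant_config = QuantConfig(ActivationCasting.DYNAMIC)

# Explicitly request per-axis scales. Passing scaling_granularity=None (the default)
# would instead let __init__ pick AxisWise for bfloat16 dynamic quantization and
# TensorWise otherwise.
quantize_to_float8(
    model,
    quant_config,
    scaling_granularity=ScalingGranularity.AxisWise,
)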
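For intuition on the dim=1 choices above: with a (tokens, in_features) activation (after the leading dims are flattened in forward) and an (out_features, in_features) weight, reducing the amax over dim=1 yields one scale per activation row and one per weight output channel. The sketch below is just that arithmetic, not the library's tensor_to_scale implementation, and it assumes dim names the reduction dimension:

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def axiswise_scale(x: torch.Tensor, dim: int) -> torch.Tensor:
    # One scale per slice: reduce |x| over `dim`, then map the amax onto the fp8 range.
    amax = x.abs().amax(dim=dim, keepdim=True).float()
    return E4M3_MAX / torch.clamp(amax, min=1e-12)

activation = torch.randn(32, 1024, dtype=torch.bfloat16)  # (tokens, in_features)
weight = torch.randn(4096, 1024, dtype=torch.bfloat16)    # (out_features, in_features)

act_scale = axiswise_scale(activation, dim=1)  # (32, 1): one scale per token row
w_scale = axiswise_scale(weight, dim=1)        # (4096, 1): one scale per output channel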