from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
from torchao.float8.float8_scaling_utils import (
- NoopFwToFloat8E5M2BwDelayed,
- NoopFwToFloat8E5M2BwDynamic,
- NoopFwToFloat8E5M2BwStatic,
+ NoopFwToFloat8BwDelayed,
+ NoopFwToFloat8BwDynamic,
+ NoopFwToFloat8BwStatic,
_maybe_initialize_amaxes_scales_for_float8_cast,
get_maybe_axiswise_dim,
hp_tensor_to_float8_delayed,
@@ -31,8 +31,6 @@
hp_tensor_and_scale_to_float8,
)
from torchao.float8.float8_utils import (
- e4m3_dtype,
- e5m2_dtype,
tensor_to_amax,
tensor_to_scale,
)
@@ -135,7 +133,7 @@ def forward(
else:
input_maybe_fp8 = hp_tensor_to_float8_dynamic(
input_hp,
- e4m3_dtype,
+ c.cast_config_input.dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=c.cast_config_input.scaling_granularity,
@@ -149,7 +147,7 @@ def forward(
else:
weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
weight_hp_t,
- e4m3_dtype,
+ c.cast_config_weight.dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=c.cast_config_weight.scaling_granularity,
@@ -185,7 +183,7 @@ def backward(ctx, grad_output):
else:
grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
grad_output_reshaped,
- e5m2_dtype,
+ c.cast_config_grad_output.dtype,
ctx.linear_mm_config,
gemm_input_role=GemmInputRole.GRAD_OUTPUT,
scaling_granularity=c.cast_config_grad_output.scaling_granularity,
@@ -203,7 +201,7 @@ def backward(ctx, grad_output):
# the entire tensor.
weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
weight_hp_t,
- e4m3_dtype,
+ c.cast_config_weight_for_grad_input.dtype,
ctx.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=c.cast_config_weight_for_grad_input.scaling_granularity,
@@ -235,7 +233,7 @@ def backward(ctx, grad_output):
else:
grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
grad_output_reshaped,
- e5m2_dtype,
+ c.cast_config_grad_output_for_grad_weight.dtype,
ctx.linear_mm_config,
gemm_input_role=GemmInputRole.GRAD_OUTPUT,
scaling_granularity=c.cast_config_grad_output_for_grad_weight.scaling_granularity,
@@ -249,7 +247,7 @@ def backward(ctx, grad_output):
else:
input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
input_hp_reshaped,
- e4m3_dtype,
+ c.cast_config_input_for_grad_weight.dtype,
ctx.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=c.cast_config_input_for_grad_weight.scaling_granularity,
@@ -354,11 +352,9 @@ def create_buffers(self):
# Default values for history buffers, see above TODO
history_len = self.config.delayed_scaling_config.history_len
device = self.weight.device
- # TODO(future PR): dtype values below don't have the other float8
- # flavors, fix it
- default_input = torch.finfo(torch.float8_e4m3fn).max
- default_weight = torch.finfo(torch.float8_e4m3fn).max
- default_grad_output = torch.finfo(torch.float8_e5m2).max
+ default_input = torch.finfo(self.config.cast_config_input.dtype).max
+ default_weight = torch.finfo(self.config.cast_config_weight.dtype).max
+ default_grad_output = torch.finfo(self.config.cast_config_grad_output.dtype).max

# Note: for now, create all the buffers if any are needed, to postpone
# the work to make the scale and amax syncing and history calculation
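
Editor's aside (not part of the diff): with the hardcoded e4m3/e5m2 defaults gone, the initial amax values now follow whichever float8 dtype each cast config selects. The two common flavors have very different representable maxima, which is why a shared default is wrong; a minimal standalone check in plain PyTorch:

```python
import torch

# The representable maxima of the two common float8 flavors differ by a
# factor of 128, so default amax/scale buffer values must track the
# configured dtype instead of assuming e4m3 or e5m2.
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0
```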
@@ -445,29 +441,32 @@ def cast_input_to_float8(
self.fp8_amax_history_input,
self.fp8_scale_input,
scale_fn_name,
- e4m3_dtype,
+ self.config.cast_config_input.dtype,
is_amax_initialized,
reduce_amax=True,
)
input_fp8 = hp_tensor_to_float8_delayed(
input,
self.fp8_scale_input,
- e4m3_dtype,
+ self.config.cast_config_input.dtype,
self.fp8_amax_input,
linear_mm_config=self.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
)
elif self.scaling_type_input is ScalingType.DYNAMIC:
input_fp8 = hp_tensor_to_float8_dynamic(
input,
- e4m3_dtype,
+ self.config.cast_config_input.dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
)
else:
assert self.scaling_type_input is ScalingType.STATIC
input_fp8 = hp_tensor_to_float8_static(
- input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config
+ input,
+ self.fp8_static_scale_input,
+ self.config.cast_config_input.dtype,
+ self.linear_mm_config,
)

return input_fp8
@@ -483,14 +482,14 @@ def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
self.fp8_amax_history_weight,
self.fp8_scale_weight,
scale_fn_name,
- e4m3_dtype,
+ self.config.cast_config_weight.dtype,
self.is_amax_initialized,
reduce_amax=True,
)
self.fp8_amax_weight.fill_(tensor_to_amax(weight))
return self.fp8_scale_weight
elif self.scaling_type_weight is ScalingType.DYNAMIC:
- return tensor_to_scale(weight, e4m3_dtype)
+ return tensor_to_scale(weight, self.config.cast_config_weight.dtype)
else:
assert self.scaling_type_weight is ScalingType.STATIC
return self.fp8_static_scale_weight
@@ -506,7 +505,7 @@ def cast_weight_to_float8_t(
weight_fp8 = hp_tensor_and_scale_to_float8(
weight,
weight_scale,
- e4m3_dtype,
+ self.config.cast_config_weight.dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
@@ -521,23 +520,25 @@ def cast_weight_to_original_t(self, weight: torch.Tensor):
def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
if self.scaling_type_grad_output is ScalingType.DELAYED:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
- output = NoopFwToFloat8E5M2BwDelayed.apply(
+ output = NoopFwToFloat8BwDelayed.apply(
output,
self.fp8_amax_grad_output,
self.fp8_amax_history_grad_output,
self.fp8_scale_grad_output,
scale_fn_name,
self.is_amax_initialized,
self.linear_mm_config,
+ self.config.cast_config_grad_output.dtype,
)
elif self.scaling_type_grad_output is ScalingType.DYNAMIC:
- output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config)
+ output = NoopFwToFloat8BwDynamic.apply(output, self.linear_mm_config, self.config.cast_config_grad_output.dtype)
else:
assert self.scaling_type_grad_output is ScalingType.STATIC
- output = NoopFwToFloat8E5M2BwStatic.apply(
+ output = NoopFwToFloat8BwStatic.apply(
output,
self.fp8_static_scale_grad_output,
self.linear_mm_config,
+ self.config.cast_config_grad_output.dtype,
)
return output

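Editor's aside (not part of the diff): the NoopFwToFloat8Bw* helpers renamed above are identity functions in the forward pass that cast the incoming gradient to float8 in the backward pass. A simplified, self-contained sketch of that pattern, not torchao's implementation (which returns a scaled Float8Tensor tied to a linear_mm_config):

```python
import torch


class NoopFwCastBwSketch(torch.autograd.Function):
    """Identity in forward; in backward, round-trip the gradient through a
    float8 dtype with dynamic (per-tensor) scaling. A toy stand-in for the
    NoopFwToFloat8Bw* helpers used above."""

    @staticmethod
    def forward(ctx, x, float8_dtype):
        ctx.float8_dtype = float8_dtype
        return x

    @staticmethod
    def backward(ctx, grad_out):
        dtype = ctx.float8_dtype
        # Dynamic scaling: map the gradient's amax to the dtype's max value.
        scale = torch.finfo(dtype).max / grad_out.abs().max().clamp(min=1e-12)
        grad_fp8 = (grad_out * scale).to(dtype)
        # Cast back to the original dtype so the sketch stays drop-in; the
        # real helpers instead hand back a Float8Tensor for the fp8 matmul.
        return grad_fp8.to(grad_out.dtype) / scale, None
```

Calling `NoopFwCastBwSketch.apply(x, torch.float8_e5m2)` leaves the forward activation untouched while gradients flowing back through it get quantized to the chosen float8 dtype.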
@@ -563,19 +564,15 @@ def float8_post_forward(self):
self.amax_and_scale_synced = False

def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
- has_any_axiswise_scaling = (
- self.config.cast_config_input.scaling_granularity
- is ScalingGranularity.AXISWISE
- or self.config.cast_config_weight.scaling_granularity
- is ScalingGranularity.AXISWISE
- or self.config.cast_config_grad_output.scaling_granularity
- is ScalingGranularity.AXISWISE
- or self.config.cast_config_input_for_grad_weight.scaling_granularity
- is ScalingGranularity.AXISWISE
- or self.config.cast_config_weight_for_grad_input.scaling_granularity
- is ScalingGranularity.AXISWISE
- or self.config.cast_config_grad_output_for_grad_weight.scaling_granularity
- is ScalingGranularity.AXISWISE
+ has_any_axiswise_scaling = any(
+ cc.scaling_granularity is ScalingGranularity.AXISWISE for cc in [
+ self.config.cast_config_input,
+ self.config.cast_config_weight,
+ self.config.cast_config_grad_output,
+ self.config.cast_config_input_for_grad_weight,
+ self.config.cast_config_weight_for_grad_input,
+ self.config.cast_config_grad_output_for_grad_weight,
+ ]
)

if not has_any_axiswise_scaling:
@@ -698,6 +695,7 @@ def from_float(
WeightWithDynamicFloat8CastTensor(
new_mod.weight,
new_mod.linear_mm_config,
+ new_mod.config.cast_config_weight.dtype,
)
)
elif config.cast_config_weight.scaling_type is ScalingType.DELAYED:
@@ -708,6 +706,7 @@ def from_float(
new_mod.fp8_amax_history_weight,
new_mod.fp8_scale_weight,
new_mod.linear_mm_config,
+ new_mod.config.cast_config_weight.dtype,
new_mod.is_amax_initialized,
)
)
@@ -718,6 +717,7 @@ def from_float(
new_mod.weight,
new_mod.fp8_static_scale_weight,
new_mod.linear_mm_config,
+ new_mod.config.cast_config_weight.dtype,
)
)
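
Usage sketch (not from the diff): after this change, the float8 dtype used for each cast is read from the corresponding cast config instead of being hardcoded to e4m3/e5m2. Assuming `CastConfig` exposes the same `dtype` field that the code above reads (the exact constructor keyword may differ in your version of torchao), an all-e4m3 configuration could look roughly like this:

```python
import torch
from torchao.float8.config import CastConfig, Float8LinearConfig

# Hypothetical configuration sketch: the `dtype` keyword mirrors the
# `cast_config_*.dtype` attribute read in the diff above and may not match
# the exact CastConfig signature in a given torchao release.
config = Float8LinearConfig(
    cast_config_input=CastConfig(dtype=torch.float8_e4m3fn),
    cast_config_weight=CastConfig(dtype=torch.float8_e4m3fn),
    cast_config_grad_output=CastConfig(dtype=torch.float8_e4m3fn),
)
```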