@@ -15,9 +15,9 @@
 from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
 from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_scaling_utils import (
-    NoopFwToFloat8E5M2BwDelayed,
-    NoopFwToFloat8E5M2BwDynamic,
-    NoopFwToFloat8E5M2BwStatic,
+    NoopFwToFloat8BwDelayed,
+    NoopFwToFloat8BwDynamic,
+    NoopFwToFloat8BwStatic,
     _maybe_initialize_amaxes_scales_for_float8_cast,
     get_maybe_axiswise_dim,
     hp_tensor_to_float8_delayed,
@@ -32,8 +32,6 @@
     hp_tensor_and_scale_to_float8,
 )
 from torchao.float8.float8_utils import (
-    e4m3_dtype,
-    e5m2_dtype,
     tensor_to_amax,
     tensor_to_scale,
 )
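The import hunk drops the hardcoded e4m3_dtype / e5m2_dtype constants and renames the backward-cast helpers so that e5m2 is no longer baked into their names; each cast site below reads its float8 dtype from the relevant cast config instead. For reference, a quick way to compare the two flavors those constants typically mapped to (torch.float8_e4m3fn and torch.float8_e5m2); this snippet is illustrative and not part of the diff:

import torch

# e4m3 keeps more mantissa bits (precision), e5m2 keeps more exponent bits (range).
print(torch.finfo(torch.float8_e4m3fn))  # max = 448.0
print(torch.finfo(torch.float8_e5m2))    # max = 57344.0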
@@ -136,7 +134,7 @@ def forward(
         else:
             input_maybe_fp8 = hp_tensor_to_float8_dynamic(
                 input_hp,
-                e4m3_dtype,
+                c.cast_config_input.dtype,
                 linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
                 scaling_granularity=c.cast_config_input.scaling_granularity,
@@ -150,7 +148,7 @@ def forward(
         else:
             weight_maybe_fp8_t = hp_tensor_to_float8_dynamic(
                 weight_hp_t,
-                e4m3_dtype,
+                c.cast_config_weight.dtype,
                 linear_mm_config,
                 gemm_input_role=GemmInputRole.WEIGHT,
                 scaling_granularity=c.cast_config_weight.scaling_granularity,
@@ -186,7 +184,7 @@ def backward(ctx, grad_output):
         else:
             grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
                 grad_output_reshaped,
-                e5m2_dtype,
+                c.cast_config_grad_output.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.GRAD_OUTPUT,
                 scaling_granularity=c.cast_config_grad_output.scaling_granularity,
@@ -204,7 +202,7 @@ def backward(ctx, grad_output):
             # the entire tensor.
             weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
                 weight_hp_t,
-                e4m3_dtype,
+                c.cast_config_weight_for_grad_input.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.WEIGHT,
                 scaling_granularity=c.cast_config_weight_for_grad_input.scaling_granularity,
@@ -236,7 +234,7 @@ def backward(ctx, grad_output):
         else:
             grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
                 grad_output_reshaped,
-                e5m2_dtype,
+                c.cast_config_grad_output_for_grad_weight.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.GRAD_OUTPUT,
                 scaling_granularity=c.cast_config_grad_output_for_grad_weight.scaling_granularity,
@@ -250,7 +248,7 @@ def backward(ctx, grad_output):
         else:
             input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic(
                 input_hp_reshaped,
-                e4m3_dtype,
+                c.cast_config_input_for_grad_weight.dtype,
                 ctx.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
                 scaling_granularity=c.cast_config_input_for_grad_weight.scaling_granularity,
@@ -347,11 +345,9 @@ def create_buffers(self):
         # Default values for history buffers, see above TODO
         history_len = self.config.delayed_scaling_config.history_len
         device = self.weight.device
-        # TODO(future PR): dtype values below don't have the other float8
-        # flavors, fix it
-        default_input = torch.finfo(torch.float8_e4m3fn).max
-        default_weight = torch.finfo(torch.float8_e4m3fn).max
-        default_grad_output = torch.finfo(torch.float8_e5m2).max
+        default_input = torch.finfo(self.config.cast_config_input.dtype).max
+        default_weight = torch.finfo(self.config.cast_config_weight.dtype).max
+        default_grad_output = torch.finfo(self.config.cast_config_grad_output.dtype).max
 
         # Note: for now, create all the buffers if any are needed, to postpone
         # the work to make the scale and amax syncing and history calculation
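This resolves the TODO: the default amax seeds now track whichever float8 dtype each cast config selects instead of hardcoded e4m3/e5m2 values. A minimal sketch of the idea, with simplified names rather than the actual torchao buffer layout:

import torch

def default_amax_history(dtype: torch.dtype, history_len: int) -> torch.Tensor:
    # Seed the history with the max representable value of the configured
    # float8 dtype, e.g. 448.0 for float8_e4m3fn or 57344.0 for float8_e5m2.
    return torch.full((history_len,), torch.finfo(dtype).max)

buf = default_amax_history(torch.float8_e5m2, history_len=16)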
@@ -438,29 +434,32 @@ def cast_input_to_float8(
                 self.fp8_amax_history_input,
                 self.fp8_scale_input,
                 scale_fn_name,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 is_amax_initialized,
                 reduce_amax=True,
             )
             input_fp8 = hp_tensor_to_float8_delayed(
                 input,
                 self.fp8_scale_input,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 self.fp8_amax_input,
                 linear_mm_config=self.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
             )
         elif self.scaling_type_input is ScalingType.DYNAMIC:
             input_fp8 = hp_tensor_to_float8_dynamic(
                 input,
-                e4m3_dtype,
+                self.config.cast_config_input.dtype,
                 self.linear_mm_config,
                 gemm_input_role=GemmInputRole.INPUT,
             )
         else:
             assert self.scaling_type_input is ScalingType.STATIC
             input_fp8 = hp_tensor_to_float8_static(
-                input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config
+                input,
+                self.fp8_static_scale_input,
+                self.config.cast_config_input.dtype,
+                self.linear_mm_config,
             )
 
         return input_fp8
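All three input-scaling branches (delayed, dynamic, static) now cast to self.config.cast_config_input.dtype. A rough sketch, under simplified names and assumptions rather than the torchao helpers, of where that dtype enters the scale math used by the delayed and dynamic paths:

import torch
from typing import Optional

def scale_for(x: torch.Tensor, float8_dtype: torch.dtype, amax: Optional[torch.Tensor] = None) -> torch.Tensor:
    # dynamic scaling: amax comes from the current tensor;
    # delayed scaling: amax comes from a history buffer tracked across steps;
    # static scaling: skip this and use a user-provided scale directly.
    if amax is None:
        amax = x.abs().max()
    return torch.finfo(float8_dtype).max / torch.clamp(amax.float(), min=1e-12)

s = scale_for(torch.randn(8, 8), torch.float8_e4m3fn)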
@@ -476,14 +475,14 @@ def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
                 self.fp8_amax_history_weight,
                 self.fp8_scale_weight,
                 scale_fn_name,
-                e4m3_dtype,
+                self.config.cast_config_weight.dtype,
                 self.is_amax_initialized,
                 reduce_amax=True,
             )
             self.fp8_amax_weight.fill_(tensor_to_amax(weight))
             return self.fp8_scale_weight
         elif self.scaling_type_weight is ScalingType.DYNAMIC:
-            return tensor_to_scale(weight, e4m3_dtype)
+            return tensor_to_scale(weight, self.config.cast_config_weight.dtype)
         else:
             assert self.scaling_type_weight is ScalingType.STATIC
             return self.fp8_static_scale_weight
@@ -499,7 +498,7 @@ def cast_weight_to_float8_t(
         weight_fp8 = hp_tensor_and_scale_to_float8(
             weight,
             weight_scale,
-            e4m3_dtype,
+            self.config.cast_config_weight.dtype,
             self.linear_mm_config,
             gemm_input_role=GemmInputRole.WEIGHT,
         )
@@ -514,23 +513,29 @@ def cast_weight_to_original_t(self, weight: torch.Tensor):
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
-            output = NoopFwToFloat8E5M2BwDelayed.apply(
+            output = NoopFwToFloat8BwDelayed.apply(
                 output,
                 self.fp8_amax_grad_output,
                 self.fp8_amax_history_grad_output,
                 self.fp8_scale_grad_output,
                 scale_fn_name,
                 self.is_amax_initialized,
                 self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
             )
         elif self.scaling_type_grad_output is ScalingType.DYNAMIC:
-            output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config)
+            output = NoopFwToFloat8BwDynamic.apply(
+                output,
+                self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
+            )
         else:
             assert self.scaling_type_grad_output is ScalingType.STATIC
-            output = NoopFwToFloat8E5M2BwStatic.apply(
+            output = NoopFwToFloat8BwStatic.apply(
                 output,
                 self.fp8_static_scale_grad_output,
                 self.linear_mm_config,
+                self.config.cast_config_grad_output.dtype,
             )
         return output
 
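The grad_output helpers lose the E5M2 from their names because the backward cast dtype is now passed to .apply alongside the other arguments. A hedged, self-contained sketch of that pattern, not the torchao implementation (which also threads a scale and wraps the result in a float8 tensor subclass):

import torch

class NoopFwCastBw(torch.autograd.Function):
    """Forward is a no-op; backward round-trips grad_output through target_dtype."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, target_dtype: torch.dtype):
        ctx.target_dtype = target_dtype
        return x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        g = grad_output.to(ctx.target_dtype).to(grad_output.dtype)
        return g, None  # no gradient for the dtype argument

x = torch.randn(4, requires_grad=True)
y = NoopFwCastBw.apply(x, torch.float8_e5m2)
y.sum().backward()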
@@ -547,19 +552,16 @@ def float8_post_forward(self):
             return
 
     def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
-        has_any_axiswise_scaling = (
-            self.config.cast_config_input.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_grad_output.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_input_for_grad_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_weight_for_grad_input.scaling_granularity
-            is ScalingGranularity.AXISWISE
-            or self.config.cast_config_grad_output_for_grad_weight.scaling_granularity
-            is ScalingGranularity.AXISWISE
+        has_any_axiswise_scaling = any(
+            cc.scaling_granularity is ScalingGranularity.AXISWISE
+            for cc in [
+                self.config.cast_config_input,
+                self.config.cast_config_weight,
+                self.config.cast_config_grad_output,
+                self.config.cast_config_input_for_grad_weight,
+                self.config.cast_config_weight_for_grad_input,
+                self.config.cast_config_grad_output_for_grad_weight,
+            ]
         )
 
         if not has_any_axiswise_scaling:
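The any() rewrite is behavior-preserving: it evaluates the same six granularity checks and short-circuits on the first AXISWISE hit, while making it harder to forget a cast config when new ones are added. A self-contained illustration with stand-in classes (not the torchao types):

from dataclasses import dataclass
from enum import Enum, auto

class ScalingGranularity(Enum):  # stand-in for torchao's enum
    TENSORWISE = auto()
    AXISWISE = auto()

@dataclass
class CastConfig:  # stand-in with only the field the check reads
    scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE

configs = [CastConfig(), CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)]
has_any_axiswise_scaling = any(
    cc.scaling_granularity is ScalingGranularity.AXISWISE for cc in configs
)
assert has_any_axiswise_scaling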
@@ -682,6 +684,7 @@ def from_float(
                 WeightWithDynamicFloat8CastTensor(
                     new_mod.weight,
                     new_mod.linear_mm_config,
+                    new_mod.config.cast_config_weight.dtype,
                 )
             )
         elif config.cast_config_weight.scaling_type is ScalingType.DELAYED:
@@ -692,6 +695,7 @@ def from_float(
                     new_mod.fp8_amax_history_weight,
                     new_mod.fp8_scale_weight,
                     new_mod.linear_mm_config,
+                    new_mod.config.cast_config_weight.dtype,
                     new_mod.is_amax_initialized,
                 )
             )
@@ -702,6 +706,7 @@ def from_float(
                     new_mod.weight,
                     new_mod.fp8_static_scale_weight,
                     new_mod.linear_mm_config,
+                    new_mod.config.cast_config_weight.dtype,
                 )
             )
 