from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_scaling_utils import (
-    NoopFwToFloat8BwDynamic,
    get_maybe_axiswise_dim,
    hp_tensor_to_float8_dynamic,
)
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor


-def _cast_input_to_float8(
-    input: torch.Tensor,
-    scaling_type_input: ScalingType,
-    config: Float8LinearConfig,
-    linear_mm_config: LinearMMConfig,
-) -> torch.Tensor:
-    # Duplicate the autocast logic for F.linear, so that the output
-    # of our module has the right original precision
-    if torch.is_autocast_enabled():
-        # For now, hardcode to GPU's autocast dtype
-        # if we need CPU support in the future, we can add it
-        autocast_dtype = torch.get_autocast_gpu_dtype()
-        input = input.to(autocast_dtype)
-
-    if tensor_already_casted_to_fp8(input):
-        input_fp8 = input
-    else:
-        assert scaling_type_input is ScalingType.DYNAMIC
-        input_fp8 = hp_tensor_to_float8_dynamic(
-            input,
-            config.cast_config_input.target_dtype,
-            linear_mm_config,
-            gemm_input_role=GemmInputRole.INPUT,
-        )
-    return input_fp8
-
-
def _get_weight_scale(
    weight: torch.Tensor,
    scaling_type_weight: ScalingType,
@@ -85,21 +57,6 @@ def _cast_weight_to_float8_t(
    return weight_fp8.t()


-def _cast_output_to_float8_in_bw(
-    output: torch.Tensor,
-    scaling_type_grad_output,
-    linear_mm_config: LinearMMConfig,
-    config: Float8LinearConfig,
-) -> torch.Tensor:
-    assert scaling_type_grad_output is ScalingType.DYNAMIC
-    output = NoopFwToFloat8BwDynamic.apply(
-        output,
-        linear_mm_config,
-        config.cast_config_grad_output.target_dtype,
-    )
-    return output
-
-
@torch._dynamo.allow_in_graph
class matmul_with_hp_or_float8_args(torch.autograd.Function):
    """
@@ -329,6 +286,14 @@ def __init__(self, *args, **kwargs):
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # Duplicate the autocast logic for F.linear, so that the output
+        # of our module has the right original precision
+        if torch.is_autocast_enabled():
+            # For now, hardcode to GPU's autocast dtype
+            # if we need CPU support in the future, we can add it
+            autocast_dtype = torch.get_autocast_gpu_dtype()
+            input = input.to(autocast_dtype)
+
        has_any_axiswise_scaling = any(
            cc.scaling_granularity is ScalingGranularity.AXISWISE
            for cc in [
@@ -341,18 +306,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
            ]
        )

-        input_maybe_fp8 = input
        weight_maybe_fp8_t = self.weight.t()

        # TODO(future PR): check for axiswise scaling for input, weight,
        # grad_output separately instead of together
        if not has_any_axiswise_scaling:
-            input_fp8 = _cast_input_to_float8(
-                input,
-                self.scaling_type_input,
-                self.config,
-                self.linear_mm_config,
-            )
            # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
            # weight_scale should be saved.
            weight_scale = _get_weight_scale(
@@ -375,25 +333,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
                weight_scale,
            )

-            input_maybe_fp8 = input_fp8
            weight_maybe_fp8_t = weight_fp8_t

        output = matmul_with_hp_or_float8_args.apply(
-            input_maybe_fp8,
+            input,
            weight_maybe_fp8_t,
            self.linear_mm_config,
            self.config,
        )

-        if not has_any_axiswise_scaling:
-            # Cast grad_output to float8_e5m2 during backward
-            output = _cast_output_to_float8_in_bw(
-                output,
-                self.scaling_type_grad_output,
-                self.linear_mm_config,
-                self.config,
-            )
-
        if self.bias is not None:
            output = output + self.bias.to(output.dtype)
        return output
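
The `+` block above inlines into `Float8Linear.forward` the same autocast duplication the deleted `_cast_input_to_float8` helper performed. A minimal standalone sketch of that pattern (a hypothetical module, not torchao code): any module that bypasses `F.linear` with a manual matmul has to cast its input to the autocast dtype itself, or its output silently keeps the parameter dtype.

```python
import torch
import torch.nn as nn


class AutocastAwareLinear(nn.Linear):
    """Hypothetical illustration: replicate F.linear's autocast cast by hand."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if torch.is_autocast_enabled():
            # Same hardcoded-to-GPU-dtype choice as the diff above
            input = input.to(torch.get_autocast_gpu_dtype())
        # The manual matmul bypasses F.linear, so without the cast above
        # the output would stay in the parameter dtype (e.g. fp32)
        output = torch.matmul(input, self.weight.t().to(input.dtype))
        if self.bias is not None:
            output = output + self.bias.to(output.dtype)
        return output
```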
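The deleted `_cast_output_to_float8_in_bw` relied on `NoopFwToFloat8BwDynamic`, an autograd.Function that is the identity in forward and casts `grad_output` to float8 in backward; that responsibility now lives inside `matmul_with_hp_or_float8_args`. A rough sketch of the identity-forward / cast-in-backward pattern, with a simplified per-tensor scale that is illustrative only (torchao returns a `Float8Tensor` wrapper rather than casting back up):

```python
import torch


class NoopFwCastBw(torch.autograd.Function):
    # Identity in the forward direction; casts the gradient in backward.
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, target_dtype: torch.dtype):
        ctx.target_dtype = target_dtype
        return tensor

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Simplified per-tensor dynamic scaling: fit grad_output into the
        # representable range of the target float8 dtype, then cast.
        amax = grad_output.abs().max().clamp(min=1e-12)
        scale = torch.finfo(ctx.target_dtype).max / amax
        grad_fp8 = (grad_output * scale).to(ctx.target_dtype)
        # Cast back to the original dtype so this sketch composes with
        # ordinary ops; target_dtype gets no gradient, hence None.
        return grad_fp8.to(grad_output.dtype) / scale, None
```

With the casting folded into `matmul_with_hp_or_float8_args`, the forward path no longer needs this pre/post wrapping, which is why the diff deletes both helpers.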