from float8_experimental.config import Float8LinearConfig, ScalingType

-from float8_experimental.float8_dynamic_utils import (
+from float8_experimental.float8_scaling_utils import (
+    _maybe_initialize_amaxes_scales_for_float8_cast,
+    cast_to_float8_delayed,
    cast_to_float8_e4m3_dynamic,
-    cast_to_float8_e5m2_dynamic_bw,
+    NoopFwToFloat8E5M2BwDelayed,
+    NoopFwToFloat8E5M2BwDynamic,
)

from float8_experimental.float8_tensor import (
    Float8Tensor,
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
-    to_fp8_no_autograd,
)

-from float8_experimental.float8_utils import (
-    amax_history_to_scale,
-    e4m3_dtype,
-    e5m2_dtype,
-    tensor_to_amax,
-)
+from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_amax

from float8_experimental.fsdp_utils import (
    WeightWithDelayedFloat8CastTensor,
    WeightWithDynamicFloat8CastTensor,
)


-def _maybe_initialize_amaxes_scales_for_float8_cast(
-    x,
-    cur_amax,
-    amax_history,
-    scale,
-    scale_fn_name,
-    float8_dtype,
-    is_initialized,
-    reduce_amax,
-):
-    """
-    If x is about to be cast to `float8` and the amax buffers are not initialized,
-    initializes them inplace.
-    """
-    if is_initialized:
-        return
-    with torch.no_grad():
-        # Note: we need to enable distributed reduction here in order
-        # to match numerics between single GPU and multi GPU code for
-        # activations and gradients
-        new_amax = tensor_to_amax(x, reduce_amax=reduce_amax)
-        cur_amax.fill_(new_amax)
-        amax_history[0] = new_amax
-        new_scale = amax_history_to_scale(
-            amax_history, float8_dtype, x.dtype, scale_fn_name
-        )
-        scale.copy_(new_scale)
-
-
# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
@torch._dynamo.allow_in_graph
class manual_float8_matmul(torch.autograd.Function):
@@ -127,66 +95,6 @@ def backward(ctx, grad_output_fp8):
        return grad_input, grad_weight.t()


-@torch._dynamo.allow_in_graph
-class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
-    """
-    Forward: no-op
-    Backward: convert to float8_e5m2, initialize if needed
-    """
-
-    @staticmethod
-    def forward(
-        ctx,
-        tensor,
-        fp8_amax_grad_output,
-        fp8_amax_history_grad_output,
-        fp8_scale_grad_output,
-        scale_fn_name,
-        is_amax_initialized,
-        linear_mm_config: LinearMMConfig,
-    ):
-        ctx.save_for_backward(
-            fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output
-        )
-        ctx.scale_fn_name = scale_fn_name
-        ctx.is_amax_initialized = is_amax_initialized
-        ctx.linear_mm_config = linear_mm_config
-        return tensor
-
-    @staticmethod
-    def backward(ctx, go):
-        (
-            fp8_amax_grad_output,
-            fp8_amax_history_grad_output,
-            fp8_scale_grad_output,
-        ) = ctx.saved_tensors
-        scale_fn_name = ctx.scale_fn_name
-        is_amax_initialized = ctx.is_amax_initialized
-
-        _maybe_initialize_amaxes_scales_for_float8_cast(
-            go,
-            fp8_amax_grad_output,
-            fp8_amax_history_grad_output,
-            fp8_scale_grad_output,
-            scale_fn_name,
-            e5m2_dtype,
-            is_amax_initialized,
-            reduce_amax=True,
-        )
-
-        fp8_amax_grad_output.fill_(tensor_to_amax(go))
-
-        res = to_fp8_no_autograd(
-            go,
-            fp8_scale_grad_output,
-            e5m2_dtype,
-            linear_mm_config=ctx.linear_mm_config,
-            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
-        )
-        empty_grads = None, None, None, None, None, None
-        return res, *empty_grads
-
-
class Float8Linear(torch.nn.Linear):
    """
    Note: this is **not** a public API and is only intended to be used
@@ -352,7 +260,7 @@ def cast_input_to_float8(
                is_amax_initialized,
                reduce_amax=True,
            )
-            input_fp8 = Float8Tensor.to_float8(
+            input_fp8 = cast_to_float8_delayed(
                input,
                self.fp8_scale_input,
                e4m3_dtype,
@@ -384,7 +292,7 @@ def cast_weight_to_float8(
                    reduce_amax=False,
                )

-                weight_fp8 = Float8Tensor.to_float8(
+                weight_fp8 = cast_to_float8_delayed(
                    weight,
                    self.fp8_scale_weight,
                    e4m3_dtype,
@@ -407,7 +315,7 @@ def cast_weight_to_float8(
    def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
        if self.scaling_type_grad_output is ScalingType.DELAYED:
            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
-            output = NoopFwToFloat8E5M2Bw.apply(
+            output = NoopFwToFloat8E5M2BwDelayed.apply(
                output,
                self.fp8_amax_grad_output,
                self.fp8_amax_history_grad_output,
@@ -418,7 +326,7 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
            )
        else:
            assert self.scaling_type_grad_output is ScalingType.DYNAMIC
-            output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config)
+            output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config)
        return output

    def float8_pre_forward(self, input):
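
The gist of the grad_output change: the functional `cast_to_float8_e5m2_dynamic_bw` wrapper is replaced by the `NoopFwToFloat8E5M2BwDynamic` autograd function, so the dynamic path now mirrors the delayed-scaling class and both branches of `cast_output_to_float8_in_bw` go through `.apply`. Below is a minimal, self-contained sketch of that "no-op forward, cast-to-float8 backward" pattern, not the PR's code: the class name `NoopFwCastBwSketch` is made up for illustration, and the gradient is round-tripped through `torch.float8_e5m2` back to the original dtype (the real class returns a `Float8Tensor` with its own scale handling). Requires a PyTorch build with float8 dtypes.

```python
import torch


class NoopFwCastBwSketch(torch.autograd.Function):
    """No-op forward; emulate a dynamically scaled float8_e5m2 cast of the gradient in backward."""

    @staticmethod
    def forward(ctx, tensor):
        # forward is a pass-through, like the Noop* classes in this diff
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # dynamic scaling: derive the scale from the incoming gradient's amax
        amax = grad_output.abs().max().clamp(min=1e-12)
        scale = torch.finfo(torch.float8_e5m2).max / amax
        grad_fp8 = (grad_output * scale).to(torch.float8_e5m2)
        # round-trip back to the original dtype so plain autograd accepts it;
        # the real NoopFwToFloat8E5M2BwDynamic returns a Float8Tensor instead
        return grad_fp8.to(grad_output.dtype) / scale


if __name__ == "__main__":
    x = torch.randn(4, 4, requires_grad=True)
    y = NoopFwCastBwSketch.apply(x)
    y.sum().backward()
    print(x.grad)  # gradients that have been quantized through float8_e5m2
```

Routing both scaling flavors through `torch.autograd.Function` subclasses keeps the forward a no-op and confines the float8 cast to the backward pass, which is what lets `Float8Linear.cast_output_to_float8_in_bw` treat the delayed and dynamic cases symmetrically.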