
 from float8_experimental.float8_tensor import (
     Float8Tensor,
+    GemmInputRole,
+    LinearMMConfig,
     ScaledMMConfig,
     to_fp8_no_autograd,
 )
@@ -85,12 +87,12 @@ def forward(
         fp8_scale_dL_dY,
         scale_fn_name,
         is_amax_initialized,
-        mm_config: ScaledMMConfig,
+        linear_mm_config: LinearMMConfig,
     ):
         ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
         ctx.scale_fn_name = scale_fn_name
         ctx.is_amax_initialized = is_amax_initialized
-        ctx.mm_config = mm_config
+        ctx.linear_mm_config = linear_mm_config
         return tensor

     @staticmethod
@@ -113,7 +115,11 @@ def backward(ctx, go):
         fp8_amax_dL_dY.fill_(tensor_to_amax(go))

         res = to_fp8_no_autograd(
-            go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
+            go,
+            fp8_scale_dL_dY,
+            e5m2_dtype,
+            linear_mm_config=ctx.linear_mm_config,
+            gemm_input_role=GemmInputRole.DL_DY,
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
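
In the hunk above, the e5m2 cast of the incoming gradient now receives the unified config plus the role of the tensor being cast. A minimal standalone sketch of the new call shape, using only names imported from float8_experimental.float8_tensor in the first hunk; the tensor shape, the scale value, and the use of torch.float8_e5m2 in place of the library's e5m2_dtype alias are illustrative assumptions, not part of this PR:

import torch

from float8_experimental.float8_tensor import (
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
    to_fp8_no_autograd,
)

# One ScaledMMConfig per role (x, w, dL_dY), mirroring the __init__ hunk below.
linear_mm_config = LinearMMConfig(
    ScaledMMConfig(False, True, False, False),   # x
    ScaledMMConfig(False, True, False, False),   # w
    ScaledMMConfig(False, False, False, False),  # dL_dY
)

go = torch.randn(16, 32)   # stand-in for the incoming gradient
scale = torch.tensor(1.0)  # arbitrary scale; the real code derives it from the amax history
go_fp8 = to_fp8_no_autograd(
    go,
    scale,
    torch.float8_e5m2,  # the diff itself uses the library's e5m2_dtype alias
    linear_mm_config=linear_mm_config,
    gemm_input_role=GemmInputRole.DL_DY,
)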
@@ -192,12 +198,18 @@ def __init__(self, *args, **kwargs):

         self.create_buffers()

-        # Defines the behavior of the matmul in the forward and backward pass
-        self.forward_config = ScaledMMConfig(
-            emulate, True if not emulate else False, False, config.pad_inner_dim
-        )
-        self.backward_config = ScaledMMConfig(
-            emulate, False, False, config.pad_inner_dim
+        # TODO(future): user level configuration of gemms
+        self.linear_mm_config = LinearMMConfig(
+            # x
+            ScaledMMConfig(
+                emulate, True if not emulate else False, False, config.pad_inner_dim
+            ),
+            # w
+            ScaledMMConfig(
+                emulate, True if not emulate else False, False, config.pad_inner_dim
+            ),
+            # dL_dY
+            ScaledMMConfig(emulate, False, False, config.pad_inner_dim),
         )

         # Note: is_amax_initialized is not a buffer to avoid data dependent
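
For call sites migrating off forward_config/backward_config, a standalone sketch of the replacement shown in the hunk above; emulate and pad_inner_dim are plain booleans here rather than module/config state, and "not emulate" is simply a shorter spelling of "True if not emulate else False":

from float8_experimental.float8_tensor import LinearMMConfig, ScaledMMConfig

emulate = False
pad_inner_dim = False

# Before this PR: two separate gemm configs on the module.
forward_config = ScaledMMConfig(emulate, not emulate, False, pad_inner_dim)
backward_config = ScaledMMConfig(emulate, False, False, pad_inner_dim)

# After this PR: a single object holding one ScaledMMConfig per gemm input,
# in the x / w / dL_dY order used by the constructor call above.
linear_mm_config = LinearMMConfig(
    ScaledMMConfig(emulate, not emulate, False, pad_inner_dim),  # x
    ScaledMMConfig(emulate, not emulate, False, pad_inner_dim),  # w
    ScaledMMConfig(emulate, False, False, pad_inner_dim),        # dL_dY
)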
@@ -308,11 +320,12 @@ def cast_x_to_float8(
                 self.fp8_scale_x,
                 e4m3_dtype,
                 self.fp8_amax_x,
-                self.forward_config,
+                linear_mm_config=self.linear_mm_config,
+                gemm_input_role=GemmInputRole.X,
             )
         else:
             assert self.scaling_type_x is TensorScalingType.DYNAMIC
-            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
+            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config)
         return x_fp8

     def cast_w_to_float8(
@@ -339,14 +352,17 @@ def cast_w_to_float8(
                 self.fp8_scale_w,
                 e4m3_dtype,
                 self.fp8_amax_w,
-                self.forward_config,
+                linear_mm_config=self.linear_mm_config,
+                gemm_input_role=GemmInputRole.W,
             )
         else:
             assert self.scaling_type_w is TensorScalingType.DYNAMIC
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 w_fp8 = self.weight
             else:
-                w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
+                w_fp8 = cast_to_float8_e4m3_dynamic(
+                    self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
+                )
         return w_fp8

     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
@@ -359,11 +375,11 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
                 self.fp8_scale_dL_dY,
                 scale_fn_name,
                 self.is_amax_initialized,
-                self.backward_config,
+                self.linear_mm_config,
             )
         else:
             assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
+            y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config)
         return y

     def float8_pre_forward(self, x):
@@ -457,7 +473,7 @@ def from_float(
             new_mod.weight = torch.nn.Parameter(
                 WeightWithDynamicFloat8CastTensor(
                     new_mod.weight,
-                    new_mod.forward_config,
+                    new_mod.linear_mm_config,
                 )
             )
         else:
@@ -468,7 +484,7 @@ def from_float(
                     new_mod.fp8_amax_w,
                     new_mod.fp8_amax_history_w,
                     new_mod.fp8_scale_w,
-                    new_mod.forward_config,
+                    new_mod.linear_mm_config,
                     new_mod.is_amax_initialized,
                 )
             )