This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit db0ca31

[wip] make all 3 gemms in Float8Linear configurable
Summary: not ready for review yet
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: 39cb928
Pull Request resolved: #315
1 parent 7f0d6bb commit db0ca31

11 files changed: +430 -185 lines
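
Before the per-file hunks, a rough sketch of the two types this commit wires through the codebase may help. Their real definitions live in float8_experimental/float8_tensor.py, which is not part of the hunks shown here, so everything below is an inferred mirror: the enum members and the three-config arity are visible in the diff, while the field names are assumptions.

# Hypothetical, inferred mirror of the types used by this commit. The real
# classes live in float8_experimental/float8_tensor.py (not shown here);
# field names below are assumptions, only arity and enum members are visible
# in the hunks that follow.
import enum
from typing import NamedTuple


class GemmInputRole(enum.Enum):
    # Which role a Float8Tensor plays in the linear layer's gemms:
    # input activation (X), weight (W), or output gradient (DL_DY).
    X = "x"
    W = "w"
    DL_DY = "dL_dY"


class ScaledMMConfig(NamedTuple):
    # Per-gemm knobs; passed positionally in the diff as
    # (emulate, <fast-accum-like flag>, <flag left False>, pad_inner_dim).
    emulate: bool = False
    use_fast_accum: bool = False
    fp8_output: bool = False
    pad_inner_dim: bool = False


class LinearMMConfig(NamedTuple):
    # Three per-gemm configs; the diff labels the entries x, w and dL_dY,
    # matching the commit's goal of making all 3 gemms of a linear layer
    # (y = x @ w.t(), dL_dX = dL_dY @ w, dL_dW = dL_dY.t() @ x) configurable.
    x: ScaledMMConfig = ScaledMMConfig()
    w: ScaledMMConfig = ScaledMMConfig()
    dL_dY: ScaledMMConfig = ScaledMMConfig()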

float8_experimental/__init__.py

7 additions & 2 deletions

@@ -5,11 +5,16 @@
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
 from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
+from float8_experimental.float8_tensor import (
+    Float8Tensor,
+    GemmInputRole,
+    LinearMMConfig,
+    ScaledMMConfig,
+)
 
 # Needed to load Float8Tensor with weights_only = True
 from torch.serialization import add_safe_globals
 
-add_safe_globals([Float8Tensor, ScaledMMConfig])
+add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig])
 
 __all__ = ["Float8Tensor", "Float8Linear"]
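
The only behavioral effect of this file's change is serialization: the two new classes join Float8Tensor and ScaledMMConfig on torch's weights_only allow-list. A minimal sketch of why that matters, assuming a torch version that ships torch.serialization.add_safe_globals (2.4+) and an importable float8_experimental; the checkpoint path is hypothetical.

import torch

# Importing the package runs the add_safe_globals([...]) call shown above,
# so checkpoints that pickle Float8Tensor metadata (including the new
# GemmInputRole / LinearMMConfig objects) can be loaded with the safe loader.
import float8_experimental  # noqa: F401

state = torch.load("checkpoint.pt", weights_only=True)  # hypothetical path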

float8_experimental/float8_dynamic_utils.py

22 additions & 8 deletions

@@ -8,7 +8,8 @@
 
 from float8_experimental.float8_tensor import (
     Float8Tensor,
-    ScaledMMConfig,
+    GemmInputRole,
+    LinearMMConfig,
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
@@ -26,9 +27,9 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     def forward(
         ctx,
         tensor,
-        mm_config: ScaledMMConfig,
+        linear_mm_config: LinearMMConfig,
     ):
-        ctx.mm_config = mm_config
+        ctx.linear_mm_config = linear_mm_config
         return tensor
 
     @staticmethod
@@ -37,21 +38,34 @@ def backward(ctx, gradY):
             return gradY, None
         gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
         fp8_tensor = to_fp8_no_autograd(
-            gradY, gradY_scale, e5m2_dtype, mm_config=ctx.mm_config
+            gradY,
+            gradY_scale,
+            e5m2_dtype,
+            linear_mm_config=ctx.linear_mm_config,
+            gemm_input_role=GemmInputRole.DL_DY,
         )
         return fp8_tensor, None
 
 
 def cast_to_float8_e4m3_dynamic(
-    inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
+    inpt_tensor: torch.Tensor,
+    linear_mm_config: LinearMMConfig,
+    reduce_amax: bool = False,
+    gemm_input_role: GemmInputRole = GemmInputRole.X,
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
     scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
+    return Float8Tensor.to_float8(
+        inpt_tensor,
+        scale,
+        e4m3_dtype,
+        linear_mm_config=linear_mm_config,
+        gemm_input_role=gemm_input_role,
+    )
 
 
 def cast_to_float8_e5m2_dynamic_bw(
-    gradY: torch.Tensor, mm_config: ScaledMMConfig
+    gradY: torch.Tensor, linear_mm_config: LinearMMConfig
 ) -> torch.Tensor:
-    return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
+    return NoopFwToFloat8E5M2Bw.apply(gradY, linear_mm_config)
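
A hedged usage sketch of the updated helpers above. Only the function names and keyword arguments shown in this file's diff are taken as given; the ScaledMMConfig and LinearMMConfig constructor calls assume the positional arity seen in the float8_linear.py hunks further below.

import torch

from float8_experimental.float8_dynamic_utils import (
    cast_to_float8_e4m3_dynamic,
    cast_to_float8_e5m2_dynamic_bw,
)
from float8_experimental.float8_tensor import (
    GemmInputRole,
    LinearMMConfig,
    ScaledMMConfig,
)

# Same config for all three gemms; the positional flags follow the
# (emulate, fast_accum, ..., pad_inner_dim) order used in float8_linear.py.
cfg = ScaledMMConfig(False, True, False, False)
linear_mm_config = LinearMMConfig(cfg, cfg, cfg)

x = torch.randn(16, 32)
w = torch.randn(64, 32)

# Forward operands are now cast with their gemm role made explicit.
x_fp8 = cast_to_float8_e4m3_dynamic(x, linear_mm_config)  # role defaults to X
w_fp8 = cast_to_float8_e4m3_dynamic(
    w, linear_mm_config, gemm_input_role=GemmInputRole.W
)

# The gradient path is unchanged in spirit: the cast to e5m2 happens in the
# backward of a no-op autograd function, now tagged as DL_DY internally.
y = torch.randn(16, 64, requires_grad=True)
y = cast_to_float8_e5m2_dynamic_bw(y, linear_mm_config)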

float8_experimental/float8_linear.py

33 additions & 17 deletions

@@ -23,6 +23,8 @@
 
 from float8_experimental.float8_tensor import (
     Float8Tensor,
+    GemmInputRole,
+    LinearMMConfig,
     ScaledMMConfig,
     to_fp8_no_autograd,
 )
@@ -85,12 +87,12 @@ def forward(
         fp8_scale_dL_dY,
         scale_fn_name,
         is_amax_initialized,
-        mm_config: ScaledMMConfig,
+        linear_mm_config: LinearMMConfig,
     ):
         ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
         ctx.scale_fn_name = scale_fn_name
         ctx.is_amax_initialized = is_amax_initialized
-        ctx.mm_config = mm_config
+        ctx.linear_mm_config = linear_mm_config
         return tensor
 
     @staticmethod
@@ -113,7 +115,11 @@ def backward(ctx, go):
         fp8_amax_dL_dY.fill_(tensor_to_amax(go))
 
         res = to_fp8_no_autograd(
-            go, fp8_scale_dL_dY, e5m2_dtype, mm_config=ctx.mm_config
+            go,
+            fp8_scale_dL_dY,
+            e5m2_dtype,
+            linear_mm_config=ctx.linear_mm_config,
+            gemm_input_role=GemmInputRole.DL_DY,
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
@@ -192,12 +198,18 @@ def __init__(self, *args, **kwargs):
 
         self.create_buffers()
 
-        # Defines the behavior of the matmul in the forward and backward pass
-        self.forward_config = ScaledMMConfig(
-            emulate, True if not emulate else False, False, config.pad_inner_dim
-        )
-        self.backward_config = ScaledMMConfig(
-            emulate, False, False, config.pad_inner_dim
+        # TODO(future): user level configuration of gemms
+        self.linear_mm_config = LinearMMConfig(
+            # x
+            ScaledMMConfig(
+                emulate, True if not emulate else False, False, config.pad_inner_dim
+            ),
+            # w
+            ScaledMMConfig(
+                emulate, True if not emulate else False, False, config.pad_inner_dim
+            ),
+            # dL_dY
+            ScaledMMConfig(emulate, False, False, config.pad_inner_dim),
         )
 
         # Note: is_amax_initialized is not a buffer to avoid data dependent
@@ -308,11 +320,12 @@ def cast_x_to_float8(
                 self.fp8_scale_x,
                 e4m3_dtype,
                 self.fp8_amax_x,
-                self.forward_config,
+                linear_mm_config=self.linear_mm_config,
+                gemm_input_role=GemmInputRole.X,
             )
         else:
             assert self.scaling_type_x is TensorScalingType.DYNAMIC
-            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
+            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config)
         return x_fp8
 
     def cast_w_to_float8(
@@ -339,14 +352,17 @@ def cast_w_to_float8(
                 self.fp8_scale_w,
                 e4m3_dtype,
                 self.fp8_amax_w,
-                self.forward_config,
+                linear_mm_config=self.linear_mm_config,
+                gemm_input_role=GemmInputRole.W,
            )
         else:
             assert self.scaling_type_w is TensorScalingType.DYNAMIC
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 w_fp8 = self.weight
             else:
-                w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
+                w_fp8 = cast_to_float8_e4m3_dynamic(
+                    self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
+                )
         return w_fp8
 
     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
@@ -359,11 +375,11 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
                 self.fp8_scale_dL_dY,
                 scale_fn_name,
                 self.is_amax_initialized,
-                self.backward_config,
+                self.linear_mm_config,
             )
         else:
             assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
+            y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config)
         return y
 
     def float8_pre_forward(self, x):
@@ -457,7 +473,7 @@ def from_float(
             new_mod.weight = torch.nn.Parameter(
                 WeightWithDynamicFloat8CastTensor(
                     new_mod.weight,
-                    new_mod.forward_config,
+                    new_mod.linear_mm_config,
                 )
             )
         else:
@@ -468,7 +484,7 @@
                     new_mod.fp8_amax_w,
                     new_mod.fp8_amax_history_w,
                     new_mod.fp8_scale_w,
-                    new_mod.forward_config,
+                    new_mod.linear_mm_config,
                     new_mod.is_amax_initialized,
                 )
             )

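Restating the wiring Float8Linear.__init__ now performs (see the -192,12 hunk above): the module builds one ScaledMMConfig per gemm, with the fast-accumulation-like flag enabled for the x and w entries only when not emulating, and a user-facing knob is still a TODO in the diff. This is a sketch, with emulate and pad_inner_dim standing in for the module's config values.

from float8_experimental.float8_tensor import LinearMMConfig, ScaledMMConfig

emulate = False        # stand-in for the module-level `emulate` flag
pad_inner_dim = False  # stand-in for config.pad_inner_dim

linear_mm_config = LinearMMConfig(
    # x: fast accumulation on real fp8 hardware, off when emulating
    ScaledMMConfig(emulate, not emulate, False, pad_inner_dim),
    # w: same settings as x in this commit
    ScaledMMConfig(emulate, not emulate, False, pad_inner_dim),
    # dL_dY: the gradient gemm keeps fast accumulation off
    ScaledMMConfig(emulate, False, False, pad_inner_dim),
)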