This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 96162b3

[1/x] clean up casting functions

Summary:

This is the start of a cleanup of private casting functions in preparation for rowwise scaling. In this PR:

1. create `float8_scaling_utils.py` to unify functions which take a high precision tensor and return a float8 tensor, taking care of scaling
2. delete `Float8Tensor.to_float8` and move callsites to `ToFloat8ConstrFunc`, since the two functions do the same thing

The end result is a slightly cleaner state; future PRs will do more cleanups.

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 30f7160
Pull Request resolved: #339
1 parent 4cc99da commit 96162b3
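
The callsite migration described in item 2 of the summary is mechanical: the `Float8Tensor.to_float8` classmethod is replaced by a direct call to the `ToFloat8ConstrFunc` autograd function, which does the same work. A hedged before/after sketch of the pattern, using the names from the `benchmarks/bench_padding.py` hunk below (the comments are illustrative, not from the source):

```python
# before: classmethod wrapper, deleted in this commit
a_fp8 = Float8Tensor.to_float8(A, scale_a, fp8_dtype, mm_config=a_config)

# after: call the shared autograd constructor directly
a_fp8 = ToFloat8ConstrFunc.apply(
    A,                    # high precision input tensor
    scale_a,              # precomputed scale
    fp8_dtype,            # target float8 dtype
    None,                 # amax_buffer, only used for delayed scaling
    a_config,             # LinearMMConfig with per-gemm ScaledMMConfig entries
    GemmInputRole.INPUT,  # which gemm operand this tensor plays
)
```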

File tree

11 files changed: +310 −274 lines


benchmarks/bench_padding.py

Lines changed: 25 additions & 4 deletions
```diff
@@ -4,7 +4,12 @@
 import fire
 
 import torch
-from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
+from float8_experimental.float8_tensor import (
+    GemmInputRole,
+    LinearMMConfig,
+    ScaledMMConfig,
+    ToFloat8ConstrFunc,
+)
 from float8_experimental.float8_utils import pad_tensor_for_matmul
 from tabulate import tabulate
 from torch._inductor.utils import do_bench_using_profiling
@@ -50,9 +55,25 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
     b_config = ScaledMMConfig(
         emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
     )
-
-    a_fp8 = Float8Tensor.to_float8(A, scale_a, fp8_dtype, mm_config=a_config)
-    b_fp8 = Float8Tensor.to_float8(B, scale_b, fp8_dtype, mm_config=b_config)
+    a_config = LinearMMConfig(a_config, a_config, a_config)
+    b_config = LinearMMConfig(b_config, b_config, b_config)
+
+    a_fp8 = ToFloat8ConstrFunc.apply(
+        A,
+        scale_a,
+        fp8_dtype,
+        None,  # amax_buffer
+        a_config,
+        GemmInputRole.INPUT,
+    )
+    b_fp8 = ToFloat8ConstrFunc.apply(
+        B,
+        scale_b,
+        fp8_dtype,
+        None,  # amax_buffer
+        b_config,
+        GemmInputRole.WEIGHT,
+    )
 
     return a_fp8 @ b_fp8
 
```

float8_experimental/float8_dynamic_utils.py

Lines changed: 0 additions & 71 deletions
This file was deleted.
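
Its helpers now live in the new `float8_scaling_utils.py`; the `float8_linear.py` diff below imports `cast_to_float8_e4m3_dynamic` from there instead. A minimal sketch of what the relocated dynamic cast plausibly looks like (its body is not shown in this diff; `tensor_to_scale` from `float8_utils` is an assumption):

```python
import torch

from float8_experimental.float8_tensor import (
    GemmInputRole,
    LinearMMConfig,
    ToFloat8ConstrFunc,
)
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale


def cast_to_float8_e4m3_dynamic(
    hp_tensor: torch.Tensor,
    linear_mm_config: LinearMMConfig,
    reduce_amax: bool = False,
    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
):
    # dynamic scaling: compute the scale from the tensor at cast time,
    # so no amax buffer is needed (hence the None below)
    scale = tensor_to_scale(hp_tensor, e4m3_dtype, reduce_amax)
    return ToFloat8ConstrFunc.apply(
        hp_tensor, scale, e4m3_dtype, None, linear_mm_config, gemm_input_role
    )
```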

float8_experimental/float8_linear.py

Lines changed: 10 additions & 102 deletions
```diff
@@ -16,61 +16,29 @@
 
 from float8_experimental.config import Float8LinearConfig, ScalingType
 
-from float8_experimental.float8_dynamic_utils import (
+from float8_experimental.float8_scaling_utils import (
+    _maybe_initialize_amaxes_scales_for_float8_cast,
+    cast_to_float8_delayed,
     cast_to_float8_e4m3_dynamic,
-    cast_to_float8_e5m2_dynamic_bw,
+    NoopFwToFloat8E5M2BwDelayed,
+    NoopFwToFloat8E5M2BwDynamic,
 )
 
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
     LinearMMConfig,
     ScaledMMConfig,
-    to_fp8_no_autograd,
 )
 
-from float8_experimental.float8_utils import (
-    amax_history_to_scale,
-    e4m3_dtype,
-    e5m2_dtype,
-    tensor_to_amax,
-)
+from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_amax
 
 from float8_experimental.fsdp_utils import (
     WeightWithDelayedFloat8CastTensor,
     WeightWithDynamicFloat8CastTensor,
 )
 
 
-def _maybe_initialize_amaxes_scales_for_float8_cast(
-    x,
-    cur_amax,
-    amax_history,
-    scale,
-    scale_fn_name,
-    float8_dtype,
-    is_initialized,
-    reduce_amax,
-):
-    """
-    If x is about to be cast to `float8` and the amax buffers are not initialized,
-    initializes them inplace.
-    """
-    if is_initialized:
-        return
-    with torch.no_grad():
-        # Note: we need to enable distributed reduction here in order
-        # to match numerics between single GPU and multi GPU code for
-        # activations and gradients
-        new_amax = tensor_to_amax(x, reduce_amax=reduce_amax)
-        cur_amax.fill_(new_amax)
-        amax_history[0] = new_amax
-        new_scale = amax_history_to_scale(
-            amax_history, float8_dtype, x.dtype, scale_fn_name
-        )
-        scale.copy_(new_scale)
-
-
 # this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
 @torch._dynamo.allow_in_graph
 class manual_float8_matmul(torch.autograd.Function):
@@ -127,66 +95,6 @@ def backward(ctx, grad_output_fp8):
         return grad_input, grad_weight.t()
 
 
-@torch._dynamo.allow_in_graph
-class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
-    """
-    Forward: no-op
-    Backward: convert to float8_e5m2, initialize if needed
-    """
-
-    @staticmethod
-    def forward(
-        ctx,
-        tensor,
-        fp8_amax_grad_output,
-        fp8_amax_history_grad_output,
-        fp8_scale_grad_output,
-        scale_fn_name,
-        is_amax_initialized,
-        linear_mm_config: LinearMMConfig,
-    ):
-        ctx.save_for_backward(
-            fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output
-        )
-        ctx.scale_fn_name = scale_fn_name
-        ctx.is_amax_initialized = is_amax_initialized
-        ctx.linear_mm_config = linear_mm_config
-        return tensor
-
-    @staticmethod
-    def backward(ctx, go):
-        (
-            fp8_amax_grad_output,
-            fp8_amax_history_grad_output,
-            fp8_scale_grad_output,
-        ) = ctx.saved_tensors
-        scale_fn_name = ctx.scale_fn_name
-        is_amax_initialized = ctx.is_amax_initialized
-
-        _maybe_initialize_amaxes_scales_for_float8_cast(
-            go,
-            fp8_amax_grad_output,
-            fp8_amax_history_grad_output,
-            fp8_scale_grad_output,
-            scale_fn_name,
-            e5m2_dtype,
-            is_amax_initialized,
-            reduce_amax=True,
-        )
-
-        fp8_amax_grad_output.fill_(tensor_to_amax(go))
-
-        res = to_fp8_no_autograd(
-            go,
-            fp8_scale_grad_output,
-            e5m2_dtype,
-            linear_mm_config=ctx.linear_mm_config,
-            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
-        )
-        empty_grads = None, None, None, None, None, None
-        return res, *empty_grads
-
-
 class Float8Linear(torch.nn.Linear):
     """
     Note: this is **not** a public API and is only intended to be used
@@ -352,7 +260,7 @@ def cast_input_to_float8(
                 is_amax_initialized,
                 reduce_amax=True,
             )
-            input_fp8 = Float8Tensor.to_float8(
+            input_fp8 = cast_to_float8_delayed(
                 input,
                 self.fp8_scale_input,
                 e4m3_dtype,
@@ -384,7 +292,7 @@ def cast_weight_to_float8(
                 reduce_amax=False,
             )
 
-            weight_fp8 = Float8Tensor.to_float8(
+            weight_fp8 = cast_to_float8_delayed(
                 weight,
                 self.fp8_scale_weight,
                 e4m3_dtype,
@@ -407,7 +315,7 @@ def cast_weight_to_float8(
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
-            output = NoopFwToFloat8E5M2Bw.apply(
+            output = NoopFwToFloat8E5M2BwDelayed.apply(
                 output,
                 self.fp8_amax_grad_output,
                 self.fp8_amax_history_grad_output,
@@ -418,7 +326,7 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
             )
         else:
             assert self.scaling_type_grad_output is ScalingType.DYNAMIC
-            output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config)
+            output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config)
         return output
 
     def float8_pre_forward(self, input):
```
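
The delayed-scaling callsites above now route through `cast_to_float8_delayed` from the new `float8_scaling_utils.py`, whose body is not part of this page. A minimal sketch of what such a helper could look like, assuming it records the amax and defers to `ToFloat8ConstrFunc` (the signature beyond the three arguments visible in the hunks above is an assumption):

```python
import torch

from float8_experimental.float8_tensor import (
    GemmInputRole,
    LinearMMConfig,
    ToFloat8ConstrFunc,
)
from float8_experimental.float8_utils import tensor_to_amax


def cast_to_float8_delayed(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    float8_dtype: torch.dtype,
    amax_buffer: torch.Tensor,
    linear_mm_config: LinearMMConfig,
    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
):
    # delayed scaling: record the current amax so the history/scale buffers
    # stay up to date, then cast with the externally maintained scale
    amax_buffer.fill_(tensor_to_amax(tensor))
    return ToFloat8ConstrFunc.apply(
        tensor,
        scale,
        float8_dtype,
        amax_buffer,
        linear_mm_config,
        gemm_input_role,
    )
```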
