This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit dbd5d02

Update base for Update on "bring back torch.autograd.Function for float8 matmul"
Summary:

This is a redo of #316

With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`.

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
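As a point of reference, below is a minimal, illustrative sketch of the pattern the summary describes: a `torch.autograd.Function` subclass that owns both the forward and backward matmuls, so each side can choose its own casting and scaling. This is not the code added in this commit; the class name `_ScaledMM` is hypothetical, and plain high-precision matmuls stand in for the float8-scaled gemms.

```python
# Illustrative sketch only (not the implementation in this commit).
# A torch.autograd.Function override gives one place to decide, per
# tensor and per gemm, how to cast/scale in forward and backward.
import torch


class _ScaledMM(torch.autograd.Function):  # hypothetical name
    @staticmethod
    def forward(ctx, input, weight):
        # forward gemm: a plain matmul stands in for the float8 kernel;
        # real code would cast input/weight according to its cast configs
        ctx.save_for_backward(input, weight)
        return input @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # backward gemms: grad_input and grad_weight can each pick a
        # different cast/scaling strategy because both live here
        grad_input = grad_output @ weight
        grad_weight = grad_output.t() @ input
        return grad_input, grad_weight


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
out = _ScaledMM.apply(x, w)
out.sum().backward()
```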
1 parent 90f73c8 commit dbd5d02

17 files changed: +138 −225 lines

benchmarks/bench_linear_float8.py

Lines changed: 7 additions & 13 deletions
@@ -14,11 +14,7 @@
 
 import torch
 import torch.utils.benchmark as benchmark
-from float8_experimental.config import (
-    Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
-)
+from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
@@ -107,15 +103,13 @@ def main(
     device = "cuda"
     print(f"Compile is set to | {compile}")
 
-    scaling_type_input = TensorScalingType(scaling_type_input)
-    scaling_type_weight = TensorScalingType(scaling_type_weight)
-    scaling_type_grad_output = TensorScalingType(scaling_type_grad_output)
+    scaling_type_input = ScalingType(scaling_type_input)
+    scaling_type_weight = ScalingType(scaling_type_weight)
+    scaling_type_grad_output = ScalingType(scaling_type_grad_output)
     config = Float8LinearConfig(
-        cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input),
-        cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
-        cast_config_grad_output=Float8TensorCastConfig(
-            scaling_type=scaling_type_grad_output
-        ),
+        cast_config_input=CastConfig(scaling_type=scaling_type_input),
+        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
     )
 
     # LLaMa 2 70B single-node weight shapes

benchmarks/bench_multi_gpu.py

Lines changed: 4 additions & 10 deletions
@@ -14,11 +14,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
-from float8_experimental.config import (
-    Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
-)
+from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear_utils import (
     convert_to_float8_training,
     sync_float8_amax_and_scale_history,
@@ -33,11 +29,9 @@
 lr = 0.01
 
 config = Float8LinearConfig(
-    cast_config_input=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
-    cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
-    cast_config_grad_output=Float8TensorCastConfig(
-        scaling_type=TensorScalingType.DELAYED
-    ),
+    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
 )
 
 
benchmarks/profile_linear_float8.py

Lines changed: 7 additions & 13 deletions
@@ -18,11 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.config import (
-    Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
-)
+from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear_utils import (
     convert_to_float8_training,
     linear_requires_sync,
@@ -217,15 +213,13 @@ def main(
     assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
     assert dtype_filter in ("both", "float8", "bfloat16")
 
-    scaling_type_input = TensorScalingType(scaling_type_input)
-    scaling_type_weight = TensorScalingType(scaling_type_weight)
-    scaling_type_grad_output = TensorScalingType(scaling_type_grad_output)
+    scaling_type_input = ScalingType(scaling_type_input)
+    scaling_type_weight = ScalingType(scaling_type_weight)
+    scaling_type_grad_output = ScalingType(scaling_type_grad_output)
     config = Float8LinearConfig(
-        cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input),
-        cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
-        cast_config_grad_output=Float8TensorCastConfig(
-            scaling_type=scaling_type_grad_output
-        ),
+        cast_config_input=CastConfig(scaling_type=scaling_type_input),
+        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
     )
     scaling_repr = "_".join(
         [

float8_experimental/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -5,11 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
 from float8_experimental.config import (
+    CastConfig,
     DelayedScalingConfig,
     Float8GemmConfig,
     Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
+    ScalingType,
 )
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
@@ -33,10 +33,10 @@
 __all__ = [
     # configuration
     "DelayedScalingConfig",
-    "TensorScalingType",
+    "ScalingType",
     "Float8GemmConfig",
    "Float8LinearConfig",
-    "Float8TensorCastConfig",
+    "CastConfig",
     # top level UX
     "convert_to_float8_training",
     "linear_requires_sync",

float8_experimental/config.py

Lines changed: 9 additions & 8 deletions
@@ -8,25 +8,26 @@
 from dataclasses import dataclass
 
 
-class TensorScalingType(enum.Enum):
+# TODO(future): consider renaming to ScalingType
+class ScalingType(enum.Enum):
     DELAYED = "delayed"
     DYNAMIC = "dynamic"
 
     def short_str(self):
-        if self is TensorScalingType.DELAYED:
+        if self is ScalingType.DELAYED:
             return "del"
         else:
-            assert self is TensorScalingType.DYNAMIC
+            assert self is ScalingType.DYNAMIC
             return "dyn"
 
 
 @dataclass(frozen=True)
-class Float8TensorCastConfig:
+class CastConfig:
     """
     Configuration for casting a single tensor to float8
     """
 
-    scaling_type: TensorScalingType = TensorScalingType.DYNAMIC
+    scaling_type: ScalingType = ScalingType.DYNAMIC
 
 
 @dataclass(frozen=True)
@@ -74,9 +75,9 @@ class Float8LinearConfig:
     #
     # Per-tensor configuration for `input`, `weight`, `grad_output`
     #
-    cast_config_input: Float8TensorCastConfig = Float8TensorCastConfig()
-    cast_config_weight: Float8TensorCastConfig = Float8TensorCastConfig()
-    cast_config_grad_output: Float8TensorCastConfig = Float8TensorCastConfig()
+    cast_config_input: CastConfig = CastConfig()
+    cast_config_weight: CastConfig = CastConfig()
+    cast_config_grad_output: CastConfig = CastConfig()
 
     #
     # Per-gemm configuration for gemms calculating `output`, `grad_input` and
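For reference, a hedged usage sketch of the renamed configuration API follows; it mirrors the call sites updated elsewhere in this diff and assumes only the names shown above (`CastConfig`, `ScalingType`, `Float8LinearConfig`). The particular mix of scaling types is illustrative.

```python
# Usage sketch based on the renamed classes in this diff.
from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType

config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
)
```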

float8_experimental/float8_linear.py

Lines changed: 12 additions & 14 deletions
@@ -14,7 +14,7 @@
 
 import torch
 
-from float8_experimental.config import Float8LinearConfig, TensorScalingType
+from float8_experimental.config import Float8LinearConfig, ScalingType
 
 from float8_experimental.float8_dynamic_utils import (
     cast_to_float8_e4m3_dynamic,
@@ -159,9 +159,9 @@ def __init__(self, *args, **kwargs):
         self.scaling_type_grad_output = config.cast_config_grad_output.scaling_type
         # Convenience flag to skip code related to delayed scaling
         self.has_any_delayed_scaling = (
-            self.scaling_type_input is TensorScalingType.DELAYED
-            or self.scaling_type_weight is TensorScalingType.DELAYED
-            or self.scaling_type_grad_output is TensorScalingType.DELAYED
+            self.scaling_type_input is ScalingType.DELAYED
+            or self.scaling_type_weight is ScalingType.DELAYED
+            or self.scaling_type_grad_output is ScalingType.DELAYED
         )
 
         self.config = config
@@ -284,7 +284,7 @@ def cast_input_to_float8(
             autocast_dtype = torch.get_autocast_gpu_dtype()
             input = input.to(autocast_dtype)
 
-        if self.scaling_type_input is TensorScalingType.DELAYED:
+        if self.scaling_type_input is ScalingType.DELAYED:
            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
                 input,
@@ -305,14 +305,14 @@ def cast_input_to_float8(
                 gemm_input_role=GemmInputRole.INPUT,
             )
         else:
-            assert self.scaling_type_input is TensorScalingType.DYNAMIC
+            assert self.scaling_type_input is ScalingType.DYNAMIC
             input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config)
         return input_fp8
 
     def cast_weight_to_float8(
         self, weight: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
-        if self.scaling_type_weight is TensorScalingType.DELAYED:
+        if self.scaling_type_weight is ScalingType.DELAYED:
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 weight_fp8 = self.weight
             else:
@@ -337,7 +337,7 @@ def cast_weight_to_float8(
                 gemm_input_role=GemmInputRole.WEIGHT,
             )
         else:
-            assert self.scaling_type_weight is TensorScalingType.DYNAMIC
+            assert self.scaling_type_weight is ScalingType.DYNAMIC
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 weight_fp8 = self.weight
             else:
@@ -349,7 +349,7 @@ def cast_weight_to_float8(
         return weight_fp8
 
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
-        if self.scaling_type_grad_output is TensorScalingType.DELAYED:
+        if self.scaling_type_grad_output is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             output = NoopFwToFloat8E5M2Bw.apply(
                 output,
@@ -361,7 +361,7 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
                 self.linear_mm_config,
             )
         else:
-            assert self.scaling_type_grad_output is TensorScalingType.DYNAMIC
+            assert self.scaling_type_grad_output is ScalingType.DYNAMIC
             output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config)
         return output
 
@@ -448,17 +448,15 @@ def from_float(
         # 2. buffers need to be already created for the delayed scaling version
        # of the weight wrapper to be initialized
        if config.enable_fsdp_float8_all_gather:
-            if config.cast_config_weight.scaling_type is TensorScalingType.DYNAMIC:
+            if config.cast_config_weight.scaling_type is ScalingType.DYNAMIC:
                new_mod.weight = torch.nn.Parameter(
                    WeightWithDynamicFloat8CastTensor(
                        new_mod.weight,
                        new_mod.linear_mm_config,
                    )
                )
            else:
-                assert (
-                    config.cast_config_weight.scaling_type is TensorScalingType.DELAYED
-                )
+                assert config.cast_config_weight.scaling_type is ScalingType.DELAYED
                new_mod.weight = torch.nn.Parameter(
                    WeightWithDelayedFloat8CastTensor(
                        new_mod.weight,

float8_experimental/float8_linear_utils.py

Lines changed: 4 additions & 4 deletions
@@ -9,7 +9,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.config import Float8LinearConfig, TensorScalingType
+from float8_experimental.config import Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear import Float8Linear
 
 from float8_experimental.float8_utils import (
@@ -27,9 +27,9 @@ def linear_requires_sync(config: Float8LinearConfig):
     """Returns whether the given linear_type requires sync before forward."""
     return any(
         [
-            config.cast_config_input.scaling_type is TensorScalingType.DELAYED,
-            config.cast_config_weight.scaling_type is TensorScalingType.DELAYED,
-            config.cast_config_grad_output.scaling_type is TensorScalingType.DELAYED,
+            config.cast_config_input.scaling_type is ScalingType.DELAYED,
+            config.cast_config_weight.scaling_type is ScalingType.DELAYED,
+            config.cast_config_grad_output.scaling_type is ScalingType.DELAYED,
         ]
     )
 
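A hedged sketch of how `linear_requires_sync` is typically consulted in a training loop follows, using only helpers imported in the diffs above; the `training_step` wrapper and its arguments are placeholder names, not part of this commit.

```python
# Sketch of a training step that syncs amax/scale history only when the
# config uses delayed scaling; model/optimizer/batch are placeholders.
from float8_experimental.float8_linear_utils import (
    linear_requires_sync,
    sync_float8_amax_and_scale_history,
)


def training_step(model, optimizer, batch, config):
    loss = model(batch).sum()
    loss.backward()
    if linear_requires_sync(config):
        # delayed scaling keeps amax history that must be synced between steps
        sync_float8_amax_and_scale_history(model)
    optimizer.step()
    optimizer.zero_grad()
```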

float8_experimental/float8_tensor_parallel.py

Lines changed: 3 additions & 3 deletions
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from float8_experimental.config import TensorScalingType
+from float8_experimental.config import ScalingType
 from float8_experimental.float8_dynamic_utils import (
     cast_to_float8_e4m3_dynamic,
     cast_to_float8_e5m2_dynamic_bw,
@@ -28,8 +28,8 @@ def _float8_linear_supports_float8_allgather(m):
     # TODO(future): add support for delayed scaling for activations
     # and gradients
     return (
-        m.scaling_type_input == TensorScalingType.DYNAMIC
-        and m.scaling_type_grad_output == TensorScalingType.DYNAMIC
+        m.scaling_type_input == ScalingType.DYNAMIC
+        and m.scaling_type_grad_output == ScalingType.DYNAMIC
     )
 
 
float8_experimental/fsdp_utils.py

Lines changed: 2 additions & 3 deletions
@@ -33,13 +33,12 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         optim.step()
         precompute_float8_dynamic_scale_for_fsdp(model)
     """
-    from float8_experimental.config import TensorScalingType
+    from float8_experimental.config import ScalingType
     from float8_experimental.float8_linear import Float8Linear
     from torch.distributed._tensor import DTensor
 
     if any(
-        isinstance(m, Float8Linear)
-        and m.scaling_type_weight is TensorScalingType.DELAYED
+        isinstance(m, Float8Linear) and m.scaling_type_weight is ScalingType.DELAYED
         for m in module.modules()
     ):
         raise NotImplementedError("Only supports delayed scaling")
