
Commit ec64182

Authored by danielvegamyhre (Daniel Vega-Myhre) and Daniel Vega-Myhre
[float8nocompile] Simplified Float8Linear implementation which only supports dynamic tensorwise scaling (#1429)
* float8nocompile: add simplified implementation of float8linear which only supports dynamic tensorwise scaling
* address comments

Co-authored-by: Daniel Vega-Myhre <danvm@fb.com>
1 parent 33d57af commit ec64182

File tree
5 files changed: +224 -11 lines changed
Lines changed: 0 additions & 1 deletion
@@ -1,2 +1 @@
-examples/
 kernels/
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn

from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

# create model and sample input
m = (
    nn.Sequential(
        nn.Linear(32, 32),
    )
    .bfloat16()
    .cuda()
)
x = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# convert specified `torch.nn.Linear` modules to `Float8Linear`
print("calling convert_to_float8_nocompile_training")
convert_to_float8_nocompile_training(m)
print("finished convert_to_float8_nocompile_training")

for i in range(10):
    print(f"step {i}")
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()

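A quick way to confirm the swap after running the example above (not part of the commit; a minimal sketch that relies on the `Float8LinearNoCompile` class added later in this diff):

from torchao.prototype.float8nocompile.float8nocompile_linear import (
    Float8LinearNoCompile,
)

# after convert_to_float8_nocompile_training(m), the nn.Linear child of the
# Sequential should have been swapped in place for Float8LinearNoCompile
assert isinstance(m[0], Float8LinearNoCompile), type(m[0])
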
torchao/prototype/float8nocompile/float8nocompile_linear.py

Lines changed: 126 additions & 2 deletions
@@ -5,11 +5,28 @@
 # LICENSE file in the root directory of this source tree.
 """
 A simple module swap UX for a float8 version of `torch.nn.Linear` which
-does not require `torch.compile` to be performant..
+does not require `torch.compile` to be performant.
 """
+from typing import Optional

 import torch

+from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
+from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
+from torchao.float8.float8_linear import manual_float8_matmul_with_args_in_float8
+from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic
+from torchao.float8.float8_tensor import (
+    GemmInputRole,
+    hp_tensor_and_scale_to_float8,
+    LinearMMConfig,
+    ScaledMMConfig,
+)
+from torchao.float8.float8_utils import tensor_to_scale
+
+from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
+    hp_tensor_to_float8nocompile_dynamic,
+)
+

 class Float8LinearNoCompile(torch.nn.Linear):
     """
@@ -19,4 +36,111 @@ class Float8LinearNoCompile(torch.nn.Linear):
     Note: this is **prototype** and not suitable for production use.
     """

-    pass
+    def __init__(self, *args, **kwargs):
+        """
+        Additional arguments on top of `torch.nn.Linear`'s arguments:
+        * `config`: Float8LinearConfig
+        """
+        config = kwargs.pop("config")
+        emulate = config.emulate
+        super().__init__(*args, **kwargs)
+
+        self.config = config
+
+        self.linear_mm_config = LinearMMConfig(
+            # output
+            ScaledMMConfig(
+                emulate,
+                self.config.gemm_config_output.use_fast_accum,
+                False,
+                self.config.pad_inner_dim,
+            ),
+            # grad_input
+            ScaledMMConfig(
+                emulate,
+                self.config.gemm_config_grad_input.use_fast_accum,
+                False,
+                self.config.pad_inner_dim,
+            ),
+            # grad_weight
+            ScaledMMConfig(
+                emulate,
+                self.config.gemm_config_grad_weight.use_fast_accum,
+                False,
+                self.config.pad_inner_dim,
+            ),
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # TODO(danielvegamyhre): replace conversions with triton kernels
+        # TODO(danielvegamyhre): support for FSDP once dependencies are implemented
+        input_fp8 = self.cast_input_to_float8(input)
+        weight_fp8_t = self.cast_weight_to_float8_t(self.weight)
+
+        # compute fp8 matmul
+        output = manual_float8_matmul_with_args_in_float8.apply(input_fp8, weight_fp8_t)
+
+        # cast grad_output to float8_e5m2 during backward
+        return self.cast_output_to_float8_in_bw(output)
+
+    def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
+        # Duplicate the autocast logic for F.linear, so that the output
+        # of our module has the right original precision
+        if torch.is_autocast_enabled():
+            # For now, hardcode to GPU's autocast dtype
+            # if we need CPU support in the future, we can add it
+            autocast_dtype = torch.get_autocast_gpu_dtype()
+            input = input.to(autocast_dtype)
+
+        # TODO(danielvegamyhre): implement this fn in scaling_utils with call to triton kernel
+        return hp_tensor_to_float8nocompile_dynamic(
+            input,
+            self.config.cast_config_input.target_dtype,
+            self.linear_mm_config,
+            gemm_input_role=GemmInputRole.INPUT,
+        )
+
+    def cast_weight_to_float8_t(
+        self,
+        weight: torch.Tensor,
+    ) -> torch.Tensor:
+        # TODO(danielvegamyhre): replace conversion with triton kernel
+        weight_fp8 = hp_tensor_to_float8nocompile_dynamic(
+            weight,
+            self.config.cast_config_weight.target_dtype,
+            self.linear_mm_config,
+            gemm_input_role=GemmInputRole.WEIGHT,
+        )
+        return weight_fp8.t()
+
+    def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
+        # casts grad_output to float8_e5m2 for backward
+        # TODO(danielvegamyhre): replace conversion with triton kernel
+        return NoopFwToFloat8BwDynamic.apply(
+            output,
+            self.linear_mm_config,
+            self.config.cast_config_grad_output.target_dtype,
+        )
+
+    @classmethod
+    def from_float(cls, mod):
+        """
+        Create an nn.Linear with fp8 compute from a regular nn.Linear
+
+        Args:
+            mod (torch.nn.Linear): nn.Linear to convert
+            config (Optional[Float8LinearConfig]): configuration for conversion to float8
+        """
+        config = Float8LinearConfig()
+        with torch.device("meta"):
+            new_mod = cls(
+                mod.in_features,
+                mod.out_features,
+                bias=False,
+                config=config,
+            )
+        new_mod.weight = mod.weight
+        new_mod.bias = mod.bias
+
+        # TODO(danielvegamyhre): support for FSDP once dependencies are implemented
+        return new_mod

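For reference, a minimal sketch of using the new class directly on a single layer rather than through the conversion utility; it mirrors what the swap helper does per layer, and assumes a CUDA GPU with native float8 support since the default `Float8LinearConfig` does not emulate:

import torch
import torch.nn as nn

from torchao.prototype.float8nocompile.float8nocompile_linear import (
    Float8LinearNoCompile,
)

# swap a single bf16 linear layer for the float8 version
lin = nn.Linear(64, 64, bias=False).bfloat16().cuda()
fp8_lin = Float8LinearNoCompile.from_float(lin)

x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = fp8_lin(x)      # input and weight are cast to fp8, then the scaled mm runs
y.sum().backward()  # grad_output is cast to float8_e5m2 on the way back
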
torchao/prototype/float8nocompile/float8nocompile_linear_utils.py

Lines changed: 4 additions & 8 deletions
@@ -12,7 +12,9 @@
 from torchao.float8.config import Float8LinearConfig
 from torchao.float8.float8_linear_utils import swap_linear_layers

-from torchao.prototype.float8nocompile.float8_linear import Float8LinearNoCompile
+from torchao.prototype.float8nocompile.float8nocompile_linear import (
+    Float8LinearNoCompile,
+)

 log = logging.getLogger(__name__)
 log.addHandler(logging.NullHandler())
@@ -22,7 +24,6 @@ def convert_to_float8_nocompile_training(
     module: nn.Module,
     *,
     module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
-    config: Float8LinearConfig = None,
 ) -> nn.Module:
     """
     Swaps `torch.nn.Linear` in `module` with `Float8LinearNoCompile`.
@@ -37,12 +38,7 @@ def convert_to_float8_nocompile_training(
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
-    if config is None:
-        config = Float8LinearConfig()
-    from_float = lambda m: Float8LinearNoCompile.from_float(
-        m,
-        config=config,
-    )
+    from_float = lambda m: Float8LinearNoCompile.from_float(m)
     return swap_linear_layers(
         module,
         from_float,
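The only remaining knob in the simplified conversion API is `module_filter_fn`. A hedged sketch of using it to keep selected layers in high precision follows (the model and the size threshold are made up for illustration):

import torch.nn as nn

from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

model = (
    nn.Sequential(
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, 8),  # small output head, keep in bf16
    )
    .bfloat16()
    .cuda()
)

# only swap linear layers whose dims are large enough to benefit from fp8 gemms
def keep_large_layers(mod: nn.Module, fqn: str) -> bool:
    return min(mod.in_features, mod.out_features) >= 64

convert_to_float8_nocompile_training(model, module_filter_fn=keep_large_layers)
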
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Utilities for scaling high precision tensors to float8.
"""

from typing import Optional

import torch

from torchao.float8.config import ScalingGranularity
from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_tensor import (
    _ToFloat8ConstrFunc,
    Float8Tensor,
    GemmInputRole,
    LinearMMConfig,
)
from torchao.float8.float8_utils import tensor_to_scale

# avoid division by zero when calculating scale
# TODO: align this value with NVIDIA's assumptions (current value is a guess)
EPS = 1e-12


def hp_tensor_to_float8nocompile_dynamic(
    hp_tensor: torch.Tensor,
    float8_dtype: torch.dtype,
    linear_mm_config: LinearMMConfig,
    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
) -> Float8Tensor:
    """
    Given a high precision tensor `hp_tensor`,
    scales `hp_tensor` dynamically and returns a `Float8Tensor` of the result.

    Args:
        hp_tensor: the tensor to convert
        float8_dtype: the float8 dtype to use
        linear_mm_config: Defines the configuration for the scaled_mm for
            the 3 fwd/bwd gemms of linear
        gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
            the 3 fwd/bwd gemms of linear
    """
    # TODO(danielvegamyhre): replace this torch implementation with custom triton kernel
    # torch.compile and eager show different numerics for 1.0 / float32,
    # upcast to float64 to ensure same numeric between compile and eager
    amax = torch.max(torch.abs(hp_tensor)).to(torch.float64)
    scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
    scale = scale.to(torch.float32)  # scale must be fp32
    return _ToFloat8ConstrFunc.apply(
        hp_tensor,
        scale,
        float8_dtype,
        linear_mm_config,
        gemm_input_role,
        None,
    )

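The scale computation above is plain dynamic tensorwise scaling. For intuition, here is a standalone sketch of the same round trip outside the `Float8Tensor` machinery (the helper name and dtype choice are illustrative, not part of the commit):

import torch

FP8_DTYPE = torch.float8_e4m3fn
EPS = 1e-12  # same division-by-zero guard as above

def scale_and_cast(hp_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # one scale for the whole tensor, recomputed from the current values
    amax = torch.max(torch.abs(hp_tensor)).to(torch.float64)
    scale = (torch.finfo(FP8_DTYPE).max / torch.clamp(amax, min=EPS)).to(torch.float32)
    # clamp to the representable range before casting (e4m3fn has no inf)
    fp8_max = torch.finfo(FP8_DTYPE).max
    fp8_data = (hp_tensor.float() * scale).clamp(-fp8_max, fp8_max).to(FP8_DTYPE)
    return fp8_data, scale

x = torch.randn(64, 64)
x_fp8, s = scale_and_cast(x)
x_dequant = x_fp8.float() / s  # matches x up to fp8 rounding error
print((x - x_dequant).abs().max())
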
0 commit comments
