This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit aced22f

Update on "Add rowwise scaling to Float8Inference module"
# Summary

# Performance
- Need to investigate the Rowwise dynamic case, I would think this should be faster than TensorWise dynamic

```Shell
Benchmark Results:
+--------------------------+-------------+
| Variant                  |   Time (μs) |
+==========================+=============+
| BF16                     |     2540.56 |
+--------------------------+-------------+
| FP8 Dynamic              |     1512.96 |
+--------------------------+-------------+
| FP8 Static               |     1363.75 |
+--------------------------+-------------+
| FP8 Weight Only          |     2774.22 |
+--------------------------+-------------+
| FP8 Dynamic AxisWise     |     1510.82 |
+--------------------------+-------------+
| FP8 Static AxisWise      |     1438.92 |
+--------------------------+-------------+
| FP8 Weight Only AxisWise |     2762.88 |
+--------------------------+-------------+

Comparison Results:
+--------------------------+-------------+-------------------+---------------+
| Variant                  |   Time (μs) |   Speedup vs BF16 |   MAE vs BF16 |
+==========================+=============+===================+===============+
| BF16                     |     2540.56 | 1.00x             |    0          |
+--------------------------+-------------+-------------------+---------------+
| FP8 Dynamic              |     1512.96 | 1.68x             |    0.00543213 |
+--------------------------+-------------+-------------------+---------------+
| FP8 Static               |     1363.75 | 1.86x             |    0.00546265 |
+--------------------------+-------------+-------------------+---------------+
| FP8 Weight Only          |     2774.22 | 0.92x             |    0.00379944 |
+--------------------------+-------------+-------------------+---------------+
| FP8 Dynamic AxisWise     |     1510.82 | 1.68x             |    0.00543213 |
+--------------------------+-------------+-------------------+---------------+
| FP8 Static AxisWise      |     1438.92 | 1.77x             |    0.00546265 |
+--------------------------+-------------+-------------------+---------------+
| FP8 Weight Only AxisWise |     2762.88 | 0.92x             |    0.00379944 |
+--------------------------+-------------+-------------------+---------------+
```

### Numerics
Using this pytorch/ao#446

TensorWise Dynamic scaling:
```Shell
+------------+-----------------------------+----------+
| Task       | Metric                      | Value    |
+============+=============================+==========+
| winogrande | acc,none                    | 0.735596 |
|            | acc_stderr,none             | 0.012395 |
+------------+-----------------------------+----------+
| wikitext   | bits_per_byte,none          | 0.538637 |
|            | bits_per_byte_stderr,none   | N/A      |
|            | byte_perplexity,none        | 1.452600 |
|            | byte_perplexity_stderr,none | N/A      |
|            | word_perplexity,none        | 7.363215 |
|            | word_perplexity_stderr,none | N/A      |
+------------+-----------------------------+----------+
```

AxisWise Dynamic Scaling:
```Shell
+------------+-----------------------------+----------+
| Task       | Metric                      | Value    |
+============+=============================+==========+
| winogrande | acc,none                    | 0.735596 |
|            | acc_stderr,none             | 0.012395 |
+------------+-----------------------------+----------+
| wikitext   | bits_per_byte,none          | 0.538637 |
|            | bits_per_byte_stderr,none   | N/A      |
|            | byte_perplexity,none        | 1.452600 |
|            | byte_perplexity_stderr,none | N/A      |
|            | word_perplexity,none        | 7.363215 |
|            | word_perplexity_stderr,none | N/A      |
+------------+-----------------------------+----------+
```

[ghstack-poisoned]
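For readers skimming the tables: the AxisWise variants differ from the TensorWise ones only in how the quantization scale is computed, one scale per row instead of one scale for the whole tensor. Below is a minimal sketch of that difference in plain `torch` (requires a build with float8 dtypes, roughly 2.1+); the helper names are mine and this is not the module's internal implementation.

```python
import torch

FP8_E4M3_MAX = 448.0  # max representable magnitude of torch.float8_e4m3fn

def scale_rowwise(t: torch.Tensor, eps: float = 1e-12):
    # One scale per row: amax is reduced over the last dim only,
    # so an outlier in one row no longer shrinks every other row's scale.
    amax = t.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    scale = FP8_E4M3_MAX / amax
    t_fp8 = (t * scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return t_fp8, scale

def scale_tensorwise(t: torch.Tensor, eps: float = 1e-12):
    # Single scale for the whole tensor, for comparison.
    amax = t.abs().amax().clamp(min=eps)
    scale = FP8_E4M3_MAX / amax
    t_fp8 = (t * scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return t_fp8, scale

x = torch.randn(16, 4096)
x_fp8, x_scale = scale_rowwise(x)           # x_scale has shape (16, 1)
x_approx = x_fp8.to(torch.float32) / x_scale
print((x - x_approx).abs().mean())          # rough analogue of the MAE column above
```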
2 parents: 0ec7ada + d7cedf2

19 files changed: +815 / -876 lines

README.md

Lines changed: 32 additions & 17 deletions
````diff
@@ -2,11 +2,12 @@
 
 This is an early version of a library for accelerating training with float8 in native PyTorch
 according to the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf.
-The codebase strives to stay small, easily hackable, and debuggable with native PyTorch tooling.
-``torch.compile`` is supported out of the box. With ``torch.compile`` on, initial results show
+The codebase strives to stay small, easily hackable, debuggable with native PyTorch tooling,
+and composable with key systems such as autograd, ```torch.compile``` and distributed.
+With ``torch.compile`` on, initial results show
 throughput speedups of up to 1.2x on small scale (8 GPUs) LLaMa pretraining jobs.
 
-:warning: <em>See the [feature tracker](https://github.com/pytorch-labs/float8_experimental/issues/187) for upcoming features. Key features such as weight cast recomputation in backward and large scale distributed support are not ready yet. </em>
+:warning: <em>See the [feature tracker](https://github.com/pytorch-labs/float8_experimental/issues/187) for upcoming features.</em>
 
 :warning: <em>Backwards compatibility is not guaranteed at this point. The codebase is in active development and
 will change rapidly.</em>
@@ -25,7 +26,7 @@ pip install -e .
 pip install -e ".[dev]"
 ```
 
-# User API
+# Single GPU User API
 
 We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`x`), weights (`w`) and gradients (`dL_dY`).
 
@@ -37,6 +38,7 @@ This is the most accurate recipe as every tensor is scaled dynamically.
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
 )
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 from float8_experimental.float8_linear import Float8Linear
 
 # create model
@@ -51,7 +53,18 @@ model = FSDP(model, use_orig_params=True)
 # optional: enable torch.compile for improved performance
 m = torch.compile(m)
 
-# train/finetune (not shown)
+# toy training loop
+for _ in range(N_ITER):
+    optimizer.zero_grad()
+    y = m(x)
+    y.sum().backward()
+    optimizer.step()
+
+    # specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
+    # this method is optional but is highly recommended for performance
+    # it calcuclates scales for all parameters in a single all-reduce
+    precompute_float8_dynamic_scale_for_fsdp(model)
+
 ```
 
 ## float8 linear with delayed scaling
@@ -71,7 +84,7 @@ m = Model(...)
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
 # type
 swap_linear_with_float8_linear(
-    m,
+    m,
     Float8Linear,
     scaling_type_x=TensorScalingType.DELAYED,
     scaling_type_w=TensorScalingType.DELAYED,
@@ -101,30 +114,32 @@ for _ in range(N_ITER):
     optimizer.step()
 ```
 
-# 🧭 Code Organization
+# Multi GPU User API
 
-* `float8_experimental/float8_linear.py`
-    - `Float8Linear` (main user facing entry point for Float8Linear)
-* `float8_experimental/float8_tensor.py`
-    - `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
-    - `ScaledMMConfig` defines the semantics for matmul in the forward and backwards pass
+We compose with the `DTensor` based [distributed APIs](https://pytorch.org/docs/stable/distributed.tensor.parallel.html),
+such as FSDP, TP and SP. Please see the [torchtitan](https://github.com/pytorch/torchtitan) repository for e2e examples
+on using `float8_experimental` in a distributed setting.
 
 # Testing
 
 ```bash
 # run single-GPU unit tests
 pytest test/test_base.py
 
-# run a single-GPU integration test on SAM
-pytest test/test_sam.py
-
 # run single-GPU compile tests
 pytest test/test_compile.py
+
+# run single-GPU numerics integration tests
+pytest test/test_numerics_integration.py
+
 # run a two-GPU integration test on FSDP
 ./test/test_fsdp.sh
 
-# run integration tests for TP/SP (outdated)
-./test/test_tp.sh
+# run integration tests on the DTensor TP/SP integration
+./test/test_dtensor.sh
+
+# run integration tests on the FSDP2 integration
+python test/test_fsdp2/test_fsdp2_eager.py
 
 # run all of these tests
 ./test/test_everything.sh
````
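Pulling the updated README snippets together, here is a minimal single-GPU sketch of the dynamic-scaling path. It elides the FSDP wrapping and the `precompute_float8_dynamic_scale_for_fsdp` step shown above, and it assumes `TensorScalingType.DYNAMIC` mirrors the `DELAYED` member used in the delayed-scaling example; the toy shapes, optimizer, and loop count are mine, so treat this as an illustration rather than a supported recipe.

```python
import torch
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# toy model standing in for `Model(...)` from the README
m = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
m = m.to(device="cuda", dtype=torch.bfloat16)

# convert all nn.Linear modules to Float8Linear, with dynamic scaling everywhere
swap_linear_with_float8_linear(
    m,
    Float8Linear,
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)

# optional: compile for the speedups quoted in the README
m = torch.compile(m)

optimizer = torch.optim.SGD(m.parameters(), lr=1e-3)
x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)

for _ in range(10):  # toy training loop, mirroring the README
    optimizer.zero_grad()
    m(x).sum().backward()
    optimizer.step()
```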

benchmarks/bench_linear_float8.py

Lines changed: 15 additions & 50 deletions
```diff
@@ -14,11 +14,9 @@
 
 import torch
 import torch.utils.benchmark as benchmark
-from float8_experimental.float8_linear import TensorScalingType
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
-    get_float8_linear,
     linear_requires_sync,
-    LinearType,
     sync_float8_amax_and_scale_history,
 )
 from float8_experimental.float8_tensor import ScaledMMConfig
@@ -69,7 +67,6 @@ class Experiment:
     dtype: torch.dtype
     compiled: bool
     use_fast_accum: bool
-    linear_type: str
     scaling_repr: str
 
     # 3 Times since we are calculating forward backward
@@ -98,7 +95,6 @@ def main(
     n_limit: Optional[int] = None,
     fast_accum_filter: Optional[bool] = None,
     shape_name_filter: Optional[str] = None,
-    linear_type_filter: Optional[str] = None,
     scaling_type_x: str = "delayed",
     scaling_type_w: str = "delayed",
     scaling_type_dL_dY: str = "delayed",
@@ -123,44 +119,28 @@ def main(
         use_fast_accum = [fast_accum_filter]
     else:
         use_fast_accum = [True, False]
-    if linear_type_filter is not None:
-        linear_types = [linear_type_filter]
-    else:
-        linear_types = ["delayed", "dynamic"]
     if shape_name_filter is not None:
         k = shape_name_filter
         name_to_shapes_70b = {k: name_to_shapes_70b[k]}
     experiment_list: List[Experiment] = []
     dtype = torch.bfloat16
-    for idx, (fast_accum, (name, (K, N)), linear_type) in enumerate(
-        tqdm(list(product(use_fast_accum, name_to_shapes_70b.items(), linear_types)))
+    for idx, (fast_accum, (name, (K, N))) in enumerate(
+        tqdm(list(product(use_fast_accum, name_to_shapes_70b.items())))
     ):
         if n_limit is not None and idx >= n_limit:
            break
        linear_ref = torch.nn.Linear(K, N, bias=input_bias).to(
            device=device, dtype=dtype
        )
-        linear_type_enum = (
-            LinearType.DELAYED if linear_type == "delayed" else LinearType.DYNAMIC
-        )
 
-        if linear_type == "delayed":
-            linear_float8 = get_float8_linear(
-                linear_type_enum,
-                copy.deepcopy(linear_ref),
-                emulate=False,
-                scaling_type_x=scaling_type_x,
-                scaling_type_w=scaling_type_w,
-                scaling_type_dL_dY=scaling_type_dL_dY,
-            )
-            scaling_repr = linear_float8.scaling_repr()
-        else:
-            linear_float8 = get_float8_linear(
-                linear_type_enum,
-                copy.deepcopy(linear_ref),
-                emulate=False,
-            )
-            scaling_repr = None
+        linear_float8 = Float8Linear.from_float(
+            copy.deepcopy(linear_ref),
+            emulate=False,
+            scaling_type_x=scaling_type_x,
+            scaling_type_w=scaling_type_w,
+            scaling_type_dL_dY=scaling_type_dL_dY,
+        )
+        scaling_repr = linear_float8.scaling_repr()
 
         if fast_accum:
             linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -172,19 +152,10 @@ def main(
         input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
         ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()
 
-        if linear_type_enum == LinearType.DELAYED:
-
-            def float8_forw_backward():
-                if linear_requires_sync(
-                    linear_type_enum, scaling_type_x, scaling_type_w, scaling_type_dL_dY
-                ):
-                    sync_float8_amax_and_scale_history(linear_float8)
-                linear_float8(input_tensor).sum().backward()
-
-        else:
-
-            def float8_forw_backward():
-                linear_float8(input_tensor).sum().backward()
+        def float8_forw_backward():
+            if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY):
+                sync_float8_amax_and_scale_history(linear_float8)
+            linear_float8(input_tensor).sum().backward()
 
         def n_times(n, fn, *args, **kwargs):
             def wrapper(*args, **kwargs):
@@ -224,7 +195,6 @@ def wrapper(*args, **kwargs):
            dtype,
            compile,
            use_fast_accum=fast_accum,
-            linear_type=linear_type,
            scaling_repr=scaling_repr,
        )
        print(experiment)
@@ -237,7 +207,6 @@ def wrapper(*args, **kwargs):
        "M",
        "K",
        "N",
-        "linear_type",
        "scaling_repr",
        "ref_dtype",
        "compiled",
@@ -257,7 +226,6 @@ def wrapper(*args, **kwargs):
            experiment.shape[0],
            experiment.shape[1],
            experiment.shape[2],
-            experiment.linear_type,
            experiment.scaling_repr,
            experiment.dtype,
            experiment.compiled,
@@ -287,7 +255,6 @@ def wrapper(*args, **kwargs):
        [
            "name",
            "shape",
-            "linear_type",
            "scaling_repr",
            "compiled",
            "use_fast_accum",
@@ -311,7 +278,6 @@ def invoke_main() -> None:
     parser.add_argument("-n", "--n_limit", type=int, required=False)
     parser.add_argument("--fast_accum_filter", type=bool, required=False)
     parser.add_argument("--shape_name_filter", type=str, required=False)
-    parser.add_argument("--linear_type_filter", type=str, required=False)
     parser.add_argument("--scaling_type_x", type=str, required=False)
     parser.add_argument("--scaling_type_w", type=str, required=False)
     parser.add_argument("--scaling_type_dL_dY", type=str, required=False)
@@ -330,7 +296,6 @@ def invoke_main() -> None:
         args.n_limit,
         args.fast_accum_filter,
         args.shape_name_filter,
-        args.linear_type_filter,
         **kwargs,
     )
```
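The script above times fused forward+backward closures with `torch.utils.benchmark`. A stripped-down version of that measurement pattern (my own simplification, not the script itself) looks like this:

```python
import torch
import torch.utils.benchmark as benchmark

def benchmark_forw_backward(fn, label: str) -> float:
    # Median wall-clock time per forward+backward call, in microseconds.
    t = benchmark.Timer(
        stmt="fn()",
        globals={"fn": fn},
        label=label,
    )
    measurement = t.blocked_autorange(min_run_time=1.0)
    return measurement.median * 1e6

# example usage with a plain bf16 linear as the reference
linear_ref = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
ref_time_us = benchmark_forw_backward(
    lambda: linear_ref(x).sum().backward(), "bf16 linear fwd+bwd"
)
print(f"bf16: {ref_time_us:.2f} us")
```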

benchmarks/bench_multi_gpu.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -14,7 +14,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
@@ -65,7 +65,13 @@ def get_model(K, N, is_fp8, base_dtype=torch.float32):
         modules.append(nn.ReLU())
     m = nn.Sequential(*modules)
     if is_fp8:
-        swap_linear_with_float8_linear(m, Float8Linear, emulate=False)
+        swap_linear_with_float8_linear(
+            m,
+            emulate=False,
+            scaling_type_x=TensorScalingType.DELAYED,
+            scaling_type_w=TensorScalingType.DELAYED,
+            scaling_type_dL_dY=TensorScalingType.DELAYED,
+        )
     return m
 
 
```
7177

benchmarks/profile_linear_float8.py

Lines changed: 24 additions & 27 deletions
```diff
@@ -18,11 +18,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
-    LinearType,
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
@@ -206,19 +204,25 @@ def profile_function(
 def main(
     profile_path_prefix: Path,
     compile: bool = True,
-    linear_type: str = "dynamic",
-    scaling_type_x: str = "delayed",
-    scaling_type_w: str = "delayed",
-    scaling_type_dL_dY: str = "delayed",
+    scaling_type_x: str = "dynamic",
+    scaling_type_w: str = "dynamic",
+    scaling_type_dL_dY: str = "dynamic",
     model_type: str = "linear",
     dtype_filter: str = "both",
 ):
     assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
     assert dtype_filter in ("both", "float8", "bfloat16")
 
-    print(f"Compile is set to | {compile}")
-    print(f"Using Linear type: | {linear_type}")
-    print(f"model_type is set to | {model_type}")
+    scaling_type_x = TensorScalingType(scaling_type_x)
+    scaling_type_w = TensorScalingType(scaling_type_w)
+    scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
+    scaling_repr = "_".join(
+        [s.short_str() for s in (scaling_type_x, scaling_type_w, scaling_type_dL_dY)]
+    )
+
+    print(f"Compile is set to | {compile}")
+    print(f"model_type is set to | {model_type}")
+    print(f"scaling_repr is set to | {scaling_repr}")
 
     device = "cuda"
     ref_dtype = torch.bfloat16
@@ -249,21 +253,14 @@ def main(
 
     m_ref = m_ref.to(device).to(ref_dtype)
 
-    linear_type = LinearType[linear_type.upper()]
-    linear_cls = (
-        Float8Linear if linear_type is LinearType.DELAYED else Float8DynamicLinear
-    )
-    extra_kwargs = {}
-    scaling_type_x = TensorScalingType(scaling_type_x)
-    scaling_type_w = TensorScalingType(scaling_type_w)
-    scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
-    if linear_type is LinearType.DELAYED:
-        extra_kwargs["scaling_type_x"] = scaling_type_x
-        extra_kwargs["scaling_type_w"] = scaling_type_w
-        extra_kwargs["scaling_type_dL_dY"] = scaling_type_dL_dY
+    extra_kwargs = {
+        "scaling_type_x": scaling_type_x,
+        "scaling_type_w": scaling_type_w,
+        "scaling_type_dL_dY": scaling_type_dL_dY,
+    }
 
     m_float8 = copy.deepcopy(m_ref)
-    swap_linear_with_float8_linear(m_float8, linear_cls, **extra_kwargs)
+    swap_linear_with_float8_linear(m_float8, **extra_kwargs)
 
     def ref_forw_backward(x):
         out = m_ref(x)
@@ -281,9 +278,7 @@ def float8_forw_backward_wrapper(x):
         # inspection of the fw+bw torch.compile without the scale
         # syncing code
         # TODO(future): make this better
-        if linear_requires_sync(
-            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
-        ):
+        if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY):
             with record_function("scale_amax_and_scales"):
                 sync_amax_history(m_float8)
         out = float8_forw(x)
@@ -345,7 +340,9 @@ def float8_forw_backward_wrapper(x):
     if dtype_filter != "bfloat16":
         # Profile Float8 Model
         print("profiling float8")
-        float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json"
+        float8_suffix = (
+            f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
+        )
         float8_path = profile_path_prefix + float8_suffix
         profile_config = ProfileConfig(
             float8_path,
```
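The profiling script ultimately writes a Chrome trace named with the `scaling_repr` suffix shown above. A simplified stand-in using the standard `torch.profiler` API (the script's `ProfileConfig` plumbing is elided, and the model here is just a plain bf16 linear) would look like:

```python
import torch
from torch.profiler import profile, record_function, ProfilerActivity

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(5):
        with record_function("fwd_bwd"):
            model(x).sum().backward()

# export a Chrome trace, analogous to the *_float8_compile_*.json files above
prof.export_chrome_trace("linear_bf16_trace.json")
```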
