Commit 8c81863

[float8] add perf benchmarks for float8 training with rowwise + tensorwise scaling (#1793)
1 parent 2b84efc commit 8c81863

File tree

1 file changed: torchao/float8/README.md

Lines changed: 39 additions & 0 deletions
@@ -202,3 +202,42 @@ python test/float8/test_fsdp2/test_fsdp2.py
# make sure to turn on torch.compile to get the best performance
./benchmarks/float8/bench_linear_float8.py -o ../tmp/test.txt --compile
```

### Training benchmarks

[Torchtitan](https://github.com/pytorch/torchtitan) was used to benchmark float8 training performance, for both rowwise
and tensorwise scaling. The training benchmarks were all run using:

- Single-node training on 8xH100 GPUs
- Batch size 1
- Sequence length 8192
- Steps 100
- `torch.compile`
- FSDP2
- pytorch version: `2.7.0a0+gitb98af95`
- torchao version: `0.10.0+git890e0ac8`
- torchtitan version: `0.0.2`
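For context, the "tensorwise" and "rowwise" recipes benchmarked in the table below map to torchao configurations along these lines (a minimal sketch; the benchmark itself is driven through torchtitan, and the FSDP2/distributed setup is omitted):

```python
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Toy stand-in for the model; the real benchmark uses torchtitan's Llama3-8b.
model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096)).cuda().to(torch.bfloat16)

# Tensorwise scaling is the default config; rowwise scaling comes from a named recipe.
config = Float8LinearConfig.from_recipe_name("rowwise")
convert_to_float8_training(model, config=config)

# torch.compile is needed to get the best performance, as noted above.
model = torch.compile(model)
```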

| Model     | Scaling                           | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over baseline |
| --------- | --------------------------------- | ------------------------ | ---------------- | -------------------- | --------------------- |
| Llama3-8b | none (bfloat16)                   | per op SAC               | 47.65            | 6150                 | -                     |
| Llama3-8b | tensorwise with float8 all-gather | per op SAC               | 47.77            | 7689.5               | 25.03%                |
| Llama3-8b | rowwise with bfloat16 all-gather  | per op SAC               | 47.79            | 6768                 | 10.05%                |

**Important notes**:
- E2E speedups increase as M, K, N (GEMM dimensions) increase. Speedups as high as 1.5x have been measured with larger shapes ([example](https://pytorch.org/blog/training-using-float8-fsdp2/)).
- Rowwise scaling is better at handling outliers than tensorwise scaling, so these recipes are different points on the accuracy vs. performance curve; the sketch below illustrates the difference in scaling granularity.
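A minimal sketch of that difference (illustrative only, not torchao's internal implementation): both granularities derive the float8 scale from an amax, but tensorwise reduces over the whole tensor while rowwise reduces per row, so a single outlier shrinks the scale for every value in the tensorwise case but only for its own row in the rowwise case.

```python
import torch

E4M3_MAX = 448.0  # max representable magnitude of float8 e4m3fn

x = torch.randn(4, 8)
x[0, 0] = 1000.0  # a single outlier

# Tensorwise: one scale for the whole tensor; the outlier shrinks every value.
scale_tensorwise = E4M3_MAX / x.abs().max()

# Rowwise: one scale per row; only the outlier's own row is affected.
scale_rowwise = E4M3_MAX / x.abs().amax(dim=-1, keepdim=True)

x_fp8_tensorwise = (x * scale_tensorwise).to(torch.float8_e4m3fn)
x_fp8_rowwise = (x * scale_rowwise).to(torch.float8_e4m3fn)
```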

**Reproducing training benchmarks**

To reproduce these benchmarks, you can follow these steps:

1. On a machine with 8 H100 GPUs, clone torchtitan and follow local installation [steps](https://github.com/pytorch/torchtitan?tab=readme-ov-file#installation), including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=readme-ov-file#downloading-a-tokenizer).
2. Install torchao following these [steps](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation).
3. From the `torchao/float8/benchmarking/` directory, you can run the following commands to reproduce the benchmarks above (a consolidated example follows the list):
    - bf16 + compile: `TORCHTITAN_ROOT=<path> ./float8_training_benchmark.sh`
    - float8 tensorwise with float8 all-gather + compile: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="tensorwise" ./float8_training_benchmark.sh`
    - float8 rowwise with bf16 all-gather + compile: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./float8_training_benchmark.sh`
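Put together, a reproduction run might look like the following (a sketch; `~/torchtitan` and `~/ao` are example paths, and the clone/install lines are placeholders for the linked installation guides):

```bash
# 1. torchtitan: clone and install per its README, including the tokenizer download
git clone https://github.com/pytorch/torchtitan ~/torchtitan

# 2. torchao: clone and install per its README
git clone https://github.com/pytorch/ao ~/ao

# 3. run the benchmark script for each configuration
cd ~/ao/torchao/float8/benchmarking
TORCHTITAN_ROOT=~/torchtitan ./float8_training_benchmark.sh
TORCHTITAN_ROOT=~/torchtitan FLOAT8_RECIPE_WITH_BEST_SETTINGS="tensorwise" ./float8_training_benchmark.sh
TORCHTITAN_ROOT=~/torchtitan FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./float8_training_benchmark.sh
```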
See the float8 training benchmarking [guide](torchao/float8/benchmarking/README.md) for more details.
