| Llama3-8b | none (bfloat16) | per op SAC | 47.65 | 6150 | - |
| Llama3-8b | tensorwise with float8 all-gather | per op SAC | 47.77 | 7689.5 | 25.03% |
| Llama3-8b | rowwise with bfloat16 all-gather | per op SAC | 47.79 | 6768 | 10.05% |
**Important notes**:
- E2E speedups increase as M, K, N (GEMM dimensions) increase. Speedups as high as 1.5x have been measured with larger shapes ([example](https://pytorch.org/blog/training-using-float8-fsdp2/)).
- Rowwise scaling handles outliers better than tensorwise scaling, so these recipes represent different points on the accuracy vs. performance curve (see the sketch after this list).
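
The two float8 rows in the table above map directly onto recipe names in torchao. The sketch below shows how a model might be converted for either recipe; `Float8LinearConfig.from_recipe_name` and `convert_to_float8_training` are torchao's float8 training API as of recent releases, while the toy model, its dimensions, and the use of `torch.compile` are illustrative assumptions rather than the exact torchtitan benchmark configuration.

```python
import torch
import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Toy stand-in for a transformer block (illustrative only): nn.Linear layers
# with dimensions divisible by 16 are eligible for float8 conversion.
model = nn.Sequential(
    nn.Linear(4096, 4096),
    nn.ReLU(),
    nn.Linear(4096, 4096),
).to(torch.bfloat16).to("cuda")

# Pick a point on the accuracy vs. performance curve:
#   "tensorwise" - one scale per tensor; fastest, corresponds to the
#                  "tensorwise with float8 all-gather" row above
#   "rowwise"    - one scale per row; more robust to outliers, corresponds to
#                  the "rowwise with bfloat16 all-gather" row above
config = Float8LinearConfig.from_recipe_name("rowwise")

# Swap eligible nn.Linear modules for their float8 counterparts in place.
convert_to_float8_training(model, config=config)

# Training then proceeds as usual; torch.compile is recommended so the scaling
# and casting ops introduced by the conversion get fused.
model = torch.compile(model)
```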
**Reproducing training benchmarks**
To reproduce these benchmarks, follow these steps:
1. On a machine with 8 H100 GPUs, clone torchtitan and follow local installation [steps](https://github.com/pytorch/torchtitan?tab=readme-ov-file#installation),
including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=readme-ov-file#downloading-a-tokenizer).
2. Install torchao following these [steps](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation).
3. From the `torchao/float8/benchmarking/` directory, run the following commands to reproduce the benchmarks above: