-| Llama3-8b | none (bf16) | per op SAC | 6019 | 47.65 |
-| Llama3-8b | tensorwise | per op SAC | 7190 | 47.77 |
-| Llama3-8b | rowwise | per op SAC | 6649 | 47.79 |
-
-In these benchmarks, tensorwise scaling achieved ~8% higher tokens/second than rowwise scaling, and ~19.5% higher than the bf16 baseline.
-However, it is important to note that rowwise scaling has been shown to yield improvements in training loss/accuracy due to reduced quantization error, particularly
-when training large models for many steps.
+| Model | Scaling | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over baseline |
+|---|---|---|---|---|---|
+| Llama3-8b | none (bf16) | per op SAC | 47.65 | 6019 | - |
+| Llama3-8b | tensorwise | per op SAC | 47.77 | 7190 | 19.45% |
+| Llama3-8b | rowwise | per op SAC | 47.79 | 6649 | 10.47% |
+
+**Important notes**:
+- Speedups increase as M, K, N (GEMM dimensions) increase. Speedups as high as 1.5x have been measured with larger shapes ([example](https://pytorch.org/blog/training-using-float8-fsdp2/)).
+- Rowwise scaling is better at handling outliers than tensorwise scaling, so these recipes are different points on the accuracy vs. performance curve.
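To make the distinction between the two recipes concrete, here is a minimal, illustrative sketch of how a model can be converted for float8 training with a chosen scaling recipe using torchao's float8 API. This is an assumption-laden example, not the exact configuration used for the benchmarks above; it assumes a recent torchao build that exposes `Float8LinearConfig.from_recipe_name` and `convert_to_float8_training`, and a GPU with float8 support.

```python
# Minimal sketch (assumptions: recent torchao with the float8 training API,
# float8-capable GPU). Illustrative only -- not the exact benchmark setup.
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Toy stand-in for a real transformer; float8 GEMMs need dims divisible by 16.
model = nn.Sequential(
    nn.Linear(4096, 4096, bias=False),
    nn.ReLU(),
    nn.Linear(4096, 4096, bias=False),
).to("cuda", dtype=torch.bfloat16)

# Pick the scaling recipe: "tensorwise" (one scale per tensor, fastest) or
# "rowwise" (one scale per row, more robust to outliers).
config = Float8LinearConfig.from_recipe_name("rowwise")

# Swap eligible nn.Linear modules for their float8 training equivalents in place.
convert_to_float8_training(model, config=config)

# Training then proceeds as usual (forward/backward/optimizer step) in bf16,
# with the matmuls inside the converted linears running in float8.
```

If no config is passed, torchao defaults to tensorwise scaling; selecting rowwise trades some throughput for better handling of outliers, as noted above.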
**Reproducing training benchmarks**
To reproduce these benchmarks, you can follow these steps: