
Commit 31238c8

Dynamic Float8 benchmarking llama (#1017)
1 parent 92feafa commit 31238c8

File tree

5 files changed: +82 additions, -12 deletions


torchao/_models/llama/benchmarks.sh

Lines changed: 9 additions & 3 deletions
@@ -9,8 +9,6 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt
 
-
-
 export MODEL_REPO=meta-llama/Meta-Llama-3-8B
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
@@ -19,6 +17,14 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt
 
+export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
+# Runs on H100, float8 is not supported on CUDA arch < 8.9
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8wo --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8dq-tensor --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8dq-row --write_result benchmark_results.txt
 
 # OTHER BENCHMARKS
 
@@ -58,4 +64,4 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
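
Note: the comment in the new benchmark block says float8 needs CUDA compute capability 8.9 or newer (H100-class GPUs). A minimal sketch, assuming a local PyTorch install, of how a wrapper script could guard the float8 runs; the helper name `supports_float8` is hypothetical:

```python
import torch

def supports_float8() -> bool:
    """True if the current GPU can run the float8 benchmarks (CUDA arch >= 8.9)."""
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (8, 9)

if __name__ == "__main__":
    # Skip the float8wo / float8dq-* invocations above when this prints False.
    print("float8 supported:", supports_float8())
```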

torchao/_models/llama/eval.py

Lines changed: 28 additions & 6 deletions
@@ -20,13 +20,16 @@
     fpx_weight_only,
     uintx_weight_only,
     unwrap_tensor_subclass,
+    float8_weight_only,
+    float8_dynamic_activation_float8_weight,
+    float8_static_activation_float8_weight,
 )
+from torchao.quantization.observer import PerRow, PerTensor
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
+from torchao._models.llama.model import prepare_inputs_for_model
 
 from tokenizer import get_tokenizer
 import time
-from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
-from torchao._models.llama.model import prepare_inputs_for_model, TransformerBlock
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 def run_evaluation(
@@ -55,19 +58,16 @@ def run_evaluation(
     tokenizer_path = checkpoint_path.parent / "tokenizer.model"
     assert tokenizer_path.is_file(), str(tokenizer_path)
     # Load Model and Tokenizer
-
     print("Loading model ...")
     t0 = time.time()
     model = _load_model(checkpoint_path, "cpu", precision)
 
     if max_length is None:
         max_length = model.config.block_size
-
     device_sync(device=device) # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
     tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
 
-
     if quantization:
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
@@ -100,6 +100,9 @@ def run_evaluation(
             from torchao.dtypes import MarlinSparseLayoutType
             quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
         if "int4wo" in quantization and "gptq" in quantization:
+            # avoid circular imports
+            from torchao._models._eval import InputRecorder
+            from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
             groupsize=int(quantization.split("-")[-2])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
             assert precision==torch.bfloat16, f"{quantization} requires precision or bfloat16 but got {precision}"
@@ -122,9 +125,24 @@ def run_evaluation(
         else:
             if not TORCH_VERSION_AT_LEAST_2_5:
                 unwrap_tensor_subclass(model)
+        if "float8wo" in quantization:
+            quantize_(model, float8_weight_only())
+        if "float8dq" in quantization:
+            granularity = str(quantization.split("-")[-1])
+            if granularity=="tensor":
+                granularity = PerTensor()
+            elif granularity=="row":
+                granularity = PerRow()
+            else:
+                if granularity=="float8dq":
+                    granularity = PerTensor()
+                else:
+                    raise ValueError(f"Unknown granularity {granularity}")
+            quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
         if "autoround" in quantization:
             from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_
             from transformers import AutoTokenizer
+            from torchao._models.llama.model import TransformerBlock
 
             _tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent)
             # parse args from quantization string:
@@ -182,6 +200,9 @@ def run_evaluation(
     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
     with torch.no_grad():
+        print("Running evaluation ...")
+        # avoid circular imports
+        from torchao._models._eval import TransformerEvalWrapper
         TransformerEvalWrapper(
             model=model.to(device),
             tokenizer=tokenizer,
@@ -209,7 +230,8 @@ def run_evaluation(
            "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, "
            "int4wo-<groupsize>-gptq, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, "
            "uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, "
-           "autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<c>"
+           "autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<c>, "
+           "float8wo, float8dq, float8saq"
         ),
     )
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
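
For reference, the float8dq branch added to eval.py turns the suffix of the quantization string into a granularity object before calling quantize_. A minimal standalone sketch of that mapping, assuming torch 2.4+ and a CUDA arch >= 8.9 GPU; `granularity_from_flag` and the toy model are illustrative only:

```python
import torch
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerRow, PerTensor

def granularity_from_flag(quantization: str):
    # "float8dq" / "float8dq-tensor" -> PerTensor(), "float8dq-row" -> PerRow()
    suffix = quantization.split("-")[-1]
    if suffix == "row":
        return PerRow()
    if suffix in ("tensor", "float8dq"):
        return PerTensor()
    raise ValueError(f"Unknown granularity {suffix}")

# Toy stand-in for the Llama model loaded by eval.py (hypothetical shape).
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False)).to(torch.bfloat16).cuda()
quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity_from_flag("float8dq-row")))
```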

torchao/_models/llama/evals.sh

Lines changed: 10 additions & 2 deletions
@@ -11,5 +11,13 @@ python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quanti
 export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
 python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
 python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoquant --tasks 'mmlu' 'truthfulqa_mc2'
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoquant --tasks 'winogrande' 'arc_challenge'
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8wo
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-tensor
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-row
+
+# Testing on additional tasks
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --tasks 'winogrande' 'arc_challenge'
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --tasks 'mmlu' 'truthfulqa_mc2'

torchao/_models/llama/generate.py

Lines changed: 15 additions & 1 deletion
@@ -210,8 +210,11 @@ def main(
            fpx_weight_only,
            uintx_weight_only,
            autoquant,
-           unwrap_tensor_subclass
+           unwrap_tensor_subclass,
+           float8_weight_only,
+           float8_dynamic_activation_float8_weight,
        )
+       from torchao.quantization.observer import PerTensor, PerRow
        if "int8wo" in quantization:
            quantize_(model, int8_weight_only())
        if "int8dq" in quantization:
@@ -243,6 +246,17 @@ def main(
            dtype = _NBITS_TO_DTYPE[nbits]
            group_size = int(_quant_args[2])
            quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
+       if "float8wo" in quantization:
+           quantize_(model, float8_weight_only())
+       if "float8dq" in quantization:
+           granularity = str(quantization.split("-")[-1])
+           if granularity=="tensor":
+               granularity = PerTensor()
+           elif granularity=="row":
+               granularity = PerRow()
+           else:
+               granularity = PerTensor()
+           quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
        if "autoquant" in quantization:
            if "autoquant-int4" == quantization:
                model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
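
generate.py now accepts both float8wo (weight-only) and float8dq (dynamic activation and weight) flags. A short sketch contrasting the two transforms on a toy module, assuming both constructors are importable from torchao.quantization alongside quantize_ (as in the import block above) and that the GPU supports float8:

```python
import copy
import torch
from torchao.quantization import (
    quantize_,
    float8_weight_only,
    float8_dynamic_activation_float8_weight,
)
from torchao.quantization.observer import PerTensor

base = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False)).to(torch.bfloat16).cuda()

# --quantization float8wo: weights stored as float8, activations left in bfloat16.
wo = copy.deepcopy(base)
quantize_(wo, float8_weight_only())

# --quantization float8dq-tensor: float8 weights plus activations cast to float8 at runtime.
dq = copy.deepcopy(base)
quantize_(dq, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
```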

torchao/quantization/README.md

Lines changed: 20 additions & 0 deletions
@@ -21,6 +21,17 @@ Benchmarks and evaluation are run on a machine with a single NVIDIA-A100-80GB GP
 |              | int4wo-64-GPTQ       | 7.921               | 180.80        | 763.33                  | 6.88             | 4.22            |
 |              | autoquant-int4hqq    | 8.110               | 188.41        | 800.58                  | 7.14             | 4.25            |
 
+Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a machine with a single NVIDIA-H100 GPU using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data.
+
+| Model        | Technique            | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
+| ------------ | -------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
+| Llama-3.1-8B | Base (bfloat16)      | 7.54                | 126.90        | 1904.75                 | 16.75            | 15.01           |
+|              | int8wo               | 7.56                | 198.85        | 1495.41                 | 11.05            | 7.52            |
+|              | int4wo-64            | 8.44                | 241.39        | 1019.14                 | 7.08             | 4.22            |
+|              | float8wo             | 7.60                | 178.46        | 1339.93                 | 12.09            | 7.51            |
+|              | float8dq (PerTensor) | 7.62                | 116.40        | 873.58                  | 11.14            | 7.51            |
+|              | float8dq (Per Row)   | 7.61                | 154.63        | 1161.47                 | 11.14            | 7.51            |
+
 note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance.
 
 For int4 we make heavy use of [tinygemm](https://github.com/pytorch/ao/blob/cb3bd8c674f2123af232a0231b5e38ddafa756a8/torchao/dtypes/aqt.py#L526) of `torch.ops.aten._weight_int4pack_mm` to bitpack into a layout optimized for tensor cores
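
As the new table shows, the float8 techniques roughly halve the bfloat16 checkpoint footprint (15.01 GB to about 7.5 GB), consistent with moving 16-bit weights to 8 bits, while int4wo-64 shrinks it to about 4.2 GB; per-row float8dq also recovers most of the throughput that per-tensor float8dq gives up (154.63 vs 116.40 tokens/s) at the same memory cost.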
@@ -121,6 +132,15 @@ from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtenso
 change_linear_weights_to_int8_dqtensors(model)
 ```
 
+#### A8W8 Float8 Dynamic Quantization
+
+```python
+# for torch 2.4+
+from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
+from torchao.quantization.observer import PerTensor
+quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
+```
+
 #### A16W6 Floating Point WeightOnly Quantization
 
 ```python
