
Commit 7e99780: Fixes

Parent: 1e0e357

4 files changed (+26, -32 lines)

torchao/_models/llama/eval.py

Lines changed: 2 additions & 17 deletions
@@ -24,7 +24,9 @@
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
 )
+from torchao.quantization.observer import PerRow, PerTensor
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
+from torchao._models.llama.model import prepare_inputs_for_model
 
 from tokenizer import get_tokenizer
 import time
@@ -56,33 +58,17 @@ def run_evaluation(
     tokenizer_path = checkpoint_path.parent / "tokenizer.model"
     assert tokenizer_path.is_file(), str(tokenizer_path)
     # Load Model and Tokenizer
-
     print("Loading model ...")
     t0 = time.time()
     model = _load_model(checkpoint_path, "cpu", precision)
 
     if max_length is None:
         max_length = model.config.block_size
-    print('Load model successfully')
     device_sync(device=device) # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
     tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
-    print('Run completed until tokenizer')
 
     if quantization:
-        from torchao.quantization.quant_api import (
-            quantize_,
-            int4_weight_only,
-            int8_weight_only,
-            int8_dynamic_activation_int8_weight,
-            fpx_weight_only,
-            uintx_weight_only,
-            unwrap_tensor_subclass,
-            float8_weight_only,
-            float8_dynamic_activation_float8_weight,
-        )
-        from torchao.quantization.observer import PerRow, PerTensor
-        print('Quantization imports completed')
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
@@ -117,7 +103,6 @@
         # avoid circular imports
         from torchao._models._eval import InputRecorder
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
-        from torchao._models.llama.model import prepare_inputs_for_model
         groupsize=int(quantization.split("-")[-2])
         assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
         assert precision==torch.bfloat16, f"{quantization} requires precision of bfloat16 but got {precision}"
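
With the float8 imports now at module level, the `float8wo` / `float8dq-*` flags exercised in `evals.sh` below presumably dispatch the same way as the `int8wo` / `int8dq` branches shown above. A minimal sketch of that mapping, assuming the branch bodies mirror the int8 pattern (the helper name `apply_float8` and the suffix-to-granularity rule are illustrative, not from this commit):

```python
# Hypothetical sketch: float8 dispatch assumed to mirror the int8wo/int8dq
# branches of run_evaluation above; not taken verbatim from this commit.
from torchao.quantization.quant_api import (
    quantize_,
    float8_weight_only,
    float8_dynamic_activation_float8_weight,
)
from torchao.quantization.observer import PerRow, PerTensor

def apply_float8(model, quantization: str):
    if "float8wo" in quantization:
        # Weight-only float8 quantization.
        quantize_(model, float8_weight_only())
    elif "float8dq" in quantization:
        # Dynamic float8 activation + weight quantization; the "-tensor" or
        # "-row" flag suffix is assumed to select the scaling granularity.
        granularity = PerRow() if "row" in quantization else PerTensor()
        quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
```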

torchao/_models/llama/evals.sh

Lines changed: 7 additions & 3 deletions
@@ -11,8 +11,12 @@ export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder
 export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
 # python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
 # python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8wo
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-tensor
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-row
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8dq
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8wo #7.60
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-tensor #7.62
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-row #7.62
 # --tasks 'mmlu' 'truthfulqa_mc2'
 # python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoquant --tasks 'winogrande' 'arc_challenge'

torchao/_models/llama/generate.py

Lines changed: 4 additions & 1 deletion
@@ -210,8 +210,11 @@ def main(
             fpx_weight_only,
             uintx_weight_only,
             autoquant,
-            unwrap_tensor_subclass
+            unwrap_tensor_subclass,
+            float8_weight_only,
+            float8_dynamic_activation_float8_weight,
         )
+        from torchao.quantization.observer import PerTensor, PerRow
         if "int8wo" in quantization:
            quantize_(model, int8_weight_only())
        if "int8dq" in quantization:

torchao/quantization/README.md

Lines changed: 13 additions & 11 deletions
@@ -20,16 +20,17 @@ Benchmarks and evaluation are run on a machine with a single NVIDIA-A100-80GB GPU
 |              | int4wo-64            | 8.316 | 180.80 | 763.33  | 6.88  | 4.22  |
 |              | int4wo-64-GPTQ       | 7.921 | 180.80 | 763.33  | 6.88  | 4.22  |
 |              | autoquant-int4hqq    | 8.110 | 188.41 | 800.58  | 7.14  | 4.25  |
-| Llama-3.1-8B | Base (bfloat16)      | 7.441 | 95.64  | 1435.54 | 16.43 | 15.01 |
-|              | int8dq               | 7.581 | 8.61   | 64.75   | 9.24  | 7.52  |
-|              | int8wo               | 7.447 | 153.03 | 1150.80 | 10.42 | 7.52  |
-|              | fp6                  | 7.661 | 161.58 | 910.02  | 7.72  | 5.63  |
-|              | int4wo-64            | 8.316 | 180.80 | 763.33  | 6.88  | 4.22  |
-|              | int4wo-64-GPTQ       | 7.921 | 180.80 | 763.33  | 6.88  | 4.22  |
-|              | autoquant-int4hqq    | 8.110 | 188.41 | 800.58  | 7.14  | 4.25  |
-|              | float8wo             | 8.316 | 180.80 | 763.33  | 6.88  | 4.22  |
-|              | float8dq (PerTensor) | 7.921 | 180.80 | 763.33  | 6.88  | 4.22  |
-|              | float8dq (Per Row)   | 8.110 | 188.41 | 800.58  | 7.14  | 4.25  |
+
+Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a machine with a single NVIDIA-H100 GPU using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data.
+
+| Model        | Technique            | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
+| ------------ | -------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
+| Llama-3.1-8B | Base (bfloat16)      | 7.54                | 126.90        | 1904.75                 | 16.75            | 15.01           |
+|              | int8wo               | 7.56                | 198.85        | 1495.41                 | 11.05            | 7.52            |
+|              | int4wo-64            | 8.44                | 241.39        | 1019.14                 | 7.08             | 4.22            |
+|              | float8wo             | 7.60                | 178.46        | 1339.93                 | 12.09            | 7.51            |
+|              | float8dq (PerTensor) | 7.62                | 116.40        | 873.58                  | 11.14            | 7.51            |
+|              | float8dq (Per Row)   | 7.62                | 154.63        | 1161.47                 | 11.14            | 7.51            |
 
 note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance.
 
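The Memory Bandwidth column appears to follow directly from decoding arithmetic: a memory-bound decoder streams every weight once per generated token, so achieved bandwidth is roughly tokens/second times model size. A quick sanity check against the table (values copied from it; the relation is an observation about the numbers, not a claim about how the scripts compute the column):

```python
# Check: bandwidth ~= tokens/s x model size for memory-bound decoding,
# using values copied from the benchmark table above.
rows = {  # technique: (tokens/s, model size in GB)
    "Base (bfloat16)": (126.90, 15.01),
    "int8wo": (198.85, 7.52),
    "float8wo": (178.46, 7.51),
}
for name, (tok_s, size_gb) in rows.items():
    print(f"{name}: {tok_s * size_gb:.2f} GB/s")
# -> 1904.77, 1495.35, 1340.23 GB/s, matching the table up to rounding.
```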
@@ -136,7 +137,8 @@ change_linear_weights_to_int8_dqtensors(model)
 ```python
 # for torch 2.4+
 from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
-quantize_(model, float8_dynamic_activation_float8_weight())
+from torchao.quantization.observer import PerTensor
+quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
 
 ```
 
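For the per-row variant benchmarked above, the same API presumably takes the `PerRow` granularity exported alongside `PerTensor`; a sketch by analogy with the snippet above (reusing its `model`, and not a line from this commit):

```python
# Per-row float8 dynamic quantization, by analogy with the PerTensor snippet.
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerRow

quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
```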