
Commit b16772d

Fixes
1 parent f5abc18 commit b16772d

File tree

2 files changed: +14 -4 lines changed


torchao/_models/llama/eval.py

Lines changed: 11 additions & 3 deletions
```diff
@@ -128,9 +128,17 @@ def run_evaluation(
         if "float8wo" in quantization:
             quantize_(model, float8_weight_only())
         if "float8dq" in quantization:
-            quantize_(model, float8_dynamic_activation_float8_weight())
-        if "float8saq" in quantization:
-            quantize_(model, float8_static_activation_float8_weight())
+            granularity = str(quantization.split("-")[-1])
+            if granularity=="tensor":
+                granularity = PerTensor()
+            elif granularity=="row":
+                granularity = PerRow()
+            else:
+                if granularity=="float8dq":
+                    granularity = PerTensor()
+                else:
+                    raise ValueError(f"Unknown granularity {granularity}")
+            quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
         if "autoround" in quantization:
             from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_
             from transformers import AutoTokenizer
```
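For context, here is a minimal standalone sketch of what the new flag parsing does. `parse_granularity` is a hypothetical helper name, and `PerTensor`/`PerRow` are stand-in classes for the torchao granularity types that eval.py uses:

```python
# Standalone sketch of the flag parsing added in the diff above.
# parse_granularity is a hypothetical helper; PerTensor and PerRow are
# stand-ins for torchao's granularity classes.

class PerTensor:
    """Stand-in: one scale for the whole weight tensor."""

class PerRow:
    """Stand-in: one scale per output row of the weight."""

def parse_granularity(quantization: str):
    """Map a CLI flag such as 'float8dq-row' to a granularity instance."""
    suffix = quantization.split("-")[-1]
    if suffix == "tensor":
        return PerTensor()
    if suffix == "row":
        return PerRow()
    if suffix == "float8dq":
        # Bare "float8dq" has no suffix; the diff defaults it to per-tensor.
        return PerTensor()
    raise ValueError(f"Unknown granularity {suffix}")

assert isinstance(parse_granularity("float8dq-tensor"), PerTensor)
assert isinstance(parse_granularity("float8dq-row"), PerRow)
assert isinstance(parse_granularity("float8dq"), PerTensor)
```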

torchao/quantization/README.md

Lines changed: 3 additions & 1 deletion
````diff
@@ -30,7 +30,7 @@ Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a ma
 | | int4wo-64 | 8.44 | 241.39 | 1019.14 | 7.08 | 4.22 |
 | | float8wo | 7.60 | 178.46 | 1339.93 | 12.09 | 7.51 |
 | | float8dq (PerTensor) | 7.62 | 116.40 | 873.58 | 11.14 | 7.51 |
-| | float8dq (Per Row) | 7.62 | 154.63 | 1161.47 | 11.14 | 7.51 |
+| | float8dq (Per Row) | 7.61 | 154.63 | 1161.47 | 11.14 | 7.51 |
 
 note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance.
 
@@ -139,6 +139,8 @@ change_linear_weights_to_int8_dqtensors(model)
 from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
 from torchao.quantization.observer import PerTensor
 quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
+from torchao.quantization.observer import PerTensor
+quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
 
 ```
 
````
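The README hunk above repeats the per-tensor example; for completeness, a hedged sketch of the per-row variant that the eval.py change enables. This is not part of the commit: it assumes `PerRow` is exported from the same `torchao.quantization.observer` module as `PerTensor` (eval.py above instantiates it), and that the model is in bfloat16 on a GPU with float8 support.

```python
# Hedged sketch (not part of the commit): per-row float8 dynamic quantization.
# Assumes PerRow lives next to PerTensor in torchao.quantization.observer,
# and that the model is bfloat16 on a float8-capable GPU.
import torch
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerRow

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()
# One float8 scale per output row of each Linear weight; activations are
# quantized dynamically at runtime.
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
```

In the benchmark table above, the per-row variant reaches higher throughput than per-tensor (154.63 vs 116.40 tok/s) at essentially the same perplexity (7.61 vs 7.62).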
