2 files changed, +14 −4 lines

```diff
@@ -128,9 +128,17 @@ def run_evaluation(
         if "float8wo" in quantization:
             quantize_(model, float8_weight_only())
         if "float8dq" in quantization:
-            quantize_(model, float8_dynamic_activation_float8_weight())
-        if "float8saq" in quantization:
-            quantize_(model, float8_static_activation_float8_weight())
+            granularity = str(quantization.split("-")[-1])
+            if granularity == "tensor":
+                granularity = PerTensor()
+            elif granularity == "row":
+                granularity = PerRow()
+            else:
+                if granularity == "float8dq":
+                    granularity = PerTensor()
+                else:
+                    raise ValueError(f"Unknown granularity {granularity}")
+            quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
         if "autoround" in quantization:
             from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_
             from transformers import AutoTokenizer
```
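For context, the new branch above derives the scaling granularity from the suffix of the `quantization` flag: `float8dq-tensor` and `float8dq-row` select per-tensor and per-row scaling, a bare `float8dq` defaults to per-tensor, and anything else raises. A minimal standalone sketch of that parsing, using a hypothetical helper name and assuming `PerTensor`/`PerRow` are importable from `torchao.quantization.observer` as in the README diff below:

```python
from torchao.quantization.observer import PerRow, PerTensor  # import path as in the README below

def parse_float8dq_granularity(quantization: str):
    """Map a flag such as 'float8dq-row' to a scaling-granularity object.

    Hypothetical helper mirroring the inline logic added above.
    """
    suffix = quantization.split("-")[-1]
    if suffix == "tensor":
        return PerTensor()
    elif suffix == "row":
        return PerRow()
    elif suffix == "float8dq":  # bare flag, no suffix: default to per-tensor scaling
        return PerTensor()
    raise ValueError(f"Unknown granularity {suffix}")

assert isinstance(parse_float8dq_granularity("float8dq"), PerTensor)
assert isinstance(parse_float8dq_granularity("float8dq-row"), PerRow)
```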
````diff
@@ -30,7 +30,7 @@ Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a ma
 |   | int4wo-64 | 8.44 | 241.39 | 1019.14 | 7.08 | 4.22 |
 |   | float8wo | 7.60 | 178.46 | 1339.93 | 12.09 | 7.51 |
 |   | float8dq (PerTensor) | 7.62 | 116.40 | 873.58 | 11.14 | 7.51 |
-|   | float8dq (Per Row) | 7.62 | 154.63 | 1161.47 | 11.14 | 7.51 |
+|   | float8dq (Per Row) | 7.61 | 154.63 | 1161.47 | 11.14 | 7.51 |
 
 note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance.
@@ -139,6 +139,8 @@ change_linear_weights_to_int8_dqtensors(model)
 from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
 from torchao.quantization.observer import PerTensor
 quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
+from torchao.quantization.observer import PerRow
+quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
 
 ```
````
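For completeness, a self-contained usage sketch of the two granularities documented above. It is a sketch under assumptions, not part of the PR: it assumes a CUDA device with float8 support (roughly compute capability 8.9 or newer) and the import paths shown in the README, and the small `nn.Sequential` is a stand-in for a real transformer:

```python
import torch
import torch.nn as nn

from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerRow, PerTensor  # paths as in the README above

# Toy stand-in for a transformer: float8 dynamic quantization converts the
# nn.Linear weights to float8 and quantizes activations on the fly.
model = nn.Sequential(
    nn.Linear(1024, 4096),
    nn.ReLU(),
    nn.Linear(4096, 1024),
).to(torch.bfloat16).to("cuda")

# One scale for the whole weight tensor (the "float8dq-tensor" eval flag) ...
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))

# ... or, on a fresh copy of the model, one scale per weight row
# (the "float8dq-row" flag), which the benchmark table above shows at
# slightly better perplexity for Llama-3.1-8B:
# quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))

out = model(torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"))
```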