Commit ac601d5: Update benchmarks.sh
1 parent: b16772d
3 files changed: 2 additions, 7 deletions

torchao/_models/llama/benchmarks.sh (1 addition, 0 deletions)

```diff
@@ -21,6 +21,7 @@ export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
+# Runs on H100, float8 is not supported on CUDA arch < 8.9
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8wo --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8dq-tensor --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization float8dq-wo --write_result benchmark_results.txt
```
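The new comment records a hardware requirement that the script itself does not enforce: the float8 runs below it need a GPU with CUDA compute capability 8.9 or higher (H100 reports 9.0; A100 at 8.0 does not qualify). A minimal sketch of a guard one could run before the float8 lines; `float8_supported` is a hypothetical helper, not part of the repo:

```python
import torch

def float8_supported() -> bool:
    # Hypothetical helper: float8 matmul kernels need CUDA compute
    # capability >= 8.9 (e.g. H100 reports (9, 0); A100 reports (8, 0)).
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (8, 9)

if __name__ == "__main__":
    print("run float8 benchmarks" if float8_supported() else "skip float8 benchmarks")
```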

torchao/_models/llama/generate.py (1 addition, 4 deletions)

```diff
@@ -255,10 +255,7 @@ def main(
             elif granularity=="row":
                 granularity = PerRow()
             else:
-                if granularity=="float8dq":
-                    granularity = PerTensor()
-                else:
-                    raise ValueError(f"Unknown granularity {granularity}")
+                granularity = PerTensor()
             quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
         if "autoquant" in quantization:
             if "autoquant-int4" == quantization:
```

torchao/quantization/README.md (0 additions, 3 deletions)

````diff
@@ -139,9 +139,6 @@ change_linear_weights_to_int8_dqtensors(model)
 from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
 from torchao.quantization.observer import PerTensor
 quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
-from torchao.quantization.observer import PerTensor
-quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
-
 ```
 
 #### A16W6 Floating Point WeightOnly Quantization
````
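The README edit drops a verbatim duplicate of the float8 snippet and keeps the one-call dynamic quantization API. A minimal end-to-end sketch of using that retained snippet, assuming a CUDA device with compute capability >= 8.9 (see the note in benchmarks.sh above); the toy two-layer model is invented for illustration:

```python
import torch
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerTensor

# Toy model; float8 kernels expect a CUDA device and bfloat16 weights.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
).to(device="cuda", dtype=torch.bfloat16)

# Replace each Linear's weight with a float8 dynamically-quantized tensor,
# using a single per-tensor scale for both activations and weights.
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([8, 1024])
```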
