Commit 5776bc0 (parent 26e790d)

Add new evaluation metrics

2 files changed (+22, -7 lines)

test/integration/test_integration.py

Lines changed: 10 additions & 1 deletion
@@ -72,7 +72,7 @@
     AQInt8WeightOnlyQuantizedLinearWeight2,
     AQInt8WeightOnlyQuantizedLinearWeight3,
     AutoQuantizableLinearWeight,
-
+    AQFloat8WeightOnlyQuantizedLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -98,6 +98,7 @@
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

 def _int8wo_api(mod):
     if TORCH_VERSION_AT_LEAST_2_4:
@@ -744,6 +745,14 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
             AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
         )

+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
+    @unittest.skipIf(not is_H100, "Need H100 to run")
+    def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
+        )
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
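For context, the new test is gated twice: on PyTorch >= 2.5 and on the `is_H100` flag added above, which checks for CUDA compute capability (8, 9) or newer, the minimum for native float8 support. A minimal standalone sketch of that gating pattern (the class and test names here are illustrative, not part of the commit):

import unittest

import torch

# Same gate as in the diff: float8 kernels need compute capability >= (8, 9).
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

class TestFloat8Gate(unittest.TestCase):
    @unittest.skipIf(not is_H100, "Need H100 to run")
    def test_capability(self):
        # Only reached on hardware where the float8 subclass test would run.
        self.assertGreaterEqual(torch.cuda.get_device_capability(), (8, 9))

if __name__ == "__main__":
    unittest.main()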

torchao/_models/llama/evals.sh

Lines changed: 12 additions & 6 deletions
@@ -1,9 +1,15 @@
 export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder

-export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround # auto-round w/o quant_lm_head
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
+# export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround # auto-round w/o quant_lm_head
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head

-export MODEL_REPO=meta-llama/Meta-Llama-3-8B
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantizatio autoround-cpu # auto-round w/o quant_lm_head
-python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu-1 # auto-round w/ quant_lm_head
+# export MODEL_REPO=meta-llama/Meta-Llama-3-8B
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
+
+export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
+# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoquant --tasks 'mmlu' 'truthfulqa_mc2'
+python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoquant --tasks 'winogrande' 'arc_challenge'
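The two newly enabled lines evaluate the autoquant-ed Meta-Llama-3.1-8B checkpoint on lm-eval tasks, split across two runs. As a companion sketch, this is roughly how torchao's top-level autoquant API is applied to a model before such an eval (the toy module below is illustrative; eval.py's actual wiring may differ):

import torch
import torchao

# Toy stand-in for the Llama checkpoint that eval.py loads.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)

# autoquant wraps eligible linear layers, benchmarks candidate quantized
# layouts (the float8 weight-only option added in this commit among them,
# on capable GPUs), and keeps the fastest.
model = torchao.autoquant(torch.compile(model, mode="max-autotune"))

# The first forward pass triggers benchmarking and finalizes kernel choice.
x = torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16)
model(x)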
