Eval hf models using lm_eval #2179
Open
jainapurva wants to merge 13 commits into main from eval_hf_models
Changes from all 13 commits:
d17ffd3 Eval hf models using lm_eval (jainapurva)
dc355d6 Throughput updates (jainapurva)
8c7583b Updates (jainapurva)
2768c17 Add sh script (jainapurva)
fa24a4f Add sh script to regenerate numbers (jainapurva)
984e518 Add sh script to regenerate numbers wikitext (jainapurva)
79f8af6 Updated code for model size (jainapurva)
c8a591a Fix readme issues (jainapurva)
116bf96 Remove throughput (jainapurva)
b4fb034 Merge remote-tracking branch 'origin/main' into eval_hf_models (jainapurva)
1b27430 Update readme (jainapurva)
0a3a2bb Update readme (jainapurva)
e14c186 Update readme (jainapurva)
File: benchmarks/_models/eval_hf_models.py (new)
@@ -0,0 +1,185 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import itertools
import subprocess

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

from benchmarks.microbenchmarks.utils import string_to_config
from torchao.quantization import *  # noqa: F401, F403
from torchao.quantization.utils import _lm_eval_available


def quantize_model_and_save(model_id, quant_config, output_dir="results"):
    """Quantize the model and save it to the output directory."""
    print("Quantizing model with config: ", quant_config)
    if quant_config is None:
        quantization_config = None
    else:
        quantization_config = TorchAoConfig(quant_type=quant_config)
    quantized_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    quantized_model.save_pretrained(output_dir, safe_serialization=False)
    tokenizer.save_pretrained(output_dir, safe_serialization=False)
    return quantized_model, tokenizer


def run_lm_eval(model_dir, tasks_list=["hellaswag"], device="cuda:0", batch_size=8):
    """Run the lm_eval command using subprocess."""
    tasks_str = ",".join(tasks_list)
    command = [
        "lm_eval",
        "--model",
        "hf",
        "--model_args",
        f"pretrained={model_dir}",
        "--tasks",
        f"{tasks_str}",
        "--device",
        f"{device}",
        "--batch_size",
        f"{batch_size}",
    ]
    subprocess.run(command, check=True)


def get_model_size_in_bytes(model, ignore_embeddings=False):
    """
    Returns the model size in bytes. The option to ignore embeddings
    is useful for models with disproportionately large embeddings compared
    to other model parameters that get quantized/sparsified.
    """

    def flat_size(tensor):
        if hasattr(tensor, "__tensor_flatten__"):
            size = 0
            # 0th element is a list of attributes that hold tensors
            for attr_name in tensor.__tensor_flatten__()[0]:
                sub_tensor = getattr(tensor, attr_name)
                size += flat_size(sub_tensor)
            return size
        else:
            return tensor.numel() * tensor.element_size()

    model_size = 0
    for _, child in model.named_children():
        if not (isinstance(child, torch.nn.Embedding) and ignore_embeddings):
            for p in itertools.chain(
                child.parameters(recurse=False), child.buffers(recurse=False)
            ):
                model_size += flat_size(p)
            model_size += get_model_size_in_bytes(child, ignore_embeddings)
    return model_size


def run(
    model_id,
    quantization,
    tasks,
    device,
    batch_size,
    model_output_dir,
):
    print(f"Running model {model_id} with quantization {quantization}")
    model_name = model_id.split("/")[-1]
model_output_dir = f"quantized_model/{model_name}-{quantization}" | ||
    quant_config = string_to_config(quantization, None)
    quantized_model, tokenizer = quantize_model_and_save(
        model_id, quant_config=quant_config, output_dir=model_output_dir
    )
    print("Compiling model ....")
    # Note: lm_eval below runs in a subprocess against the saved checkpoint,
    # so this compile only affects the in-process copy that is used for the
    # model size measurement.
    quantized_model = torch.compile(
        quantized_model,
        mode="reduce-overhead",
        fullgraph=True,
    )
    run_lm_eval(
        model_output_dir, tasks_list=tasks, device=device, batch_size=batch_size
    )
    model_size = get_model_size_in_bytes(quantized_model, ignore_embeddings=True) / 1e9
    print(f"Model size: {model_size:.2f} GB")


if __name__ == "__main__":
    if not _lm_eval_available:
        print(
            "lm_eval is required to run this script. Please install it using pip install lm-eval."
        )
        exit(0)

    # Set up argument parser
    parser = argparse.ArgumentParser(
        description="Quantize a model and evaluate it with lm_eval."
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="meta-llama/Llama-3.1-8B",
        help="The model ID to use.",
    )
    parser.add_argument(
        "--quantization",
        type=str,
        default=None,
        help="The quantization method to use.",
    )
    parser.add_argument(
        "--tasks",
        nargs="+",
        type=str,
        default=["wikitext"],
        help="List of lm_eval tasks to evaluate. Usage: --tasks task1 task2",
    )
    parser.add_argument(
        "--device", type=str, default="cuda:0", help="Device to run the model on."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1, help="Batch size for lm_eval."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="quantized_models",
        help="Output directory for quantized model.",
    )
    args = parser.parse_args()

    # Use parsed arguments
    run(
        model_id=args.model_id,
        quantization=args.quantization,
        tasks=args.tasks,
        device=args.device,
        batch_size=args.batch_size,
        model_output_dir=args.output_dir,
    )
File: benchmarks/_models/eval_hf_models.sh (new)
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


# For llama3.1-8B
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-tensor --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8wo --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int4wo-128 --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8wo --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8dq --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization gemlitewo-128 --tasks wikitext hellaswag


# For llama3.2-3B
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-row --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-tensor --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8wo --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int4wo-128 --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8wo --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8dq --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization gemlitewo-128 --tasks wikitext hellaswag
File: README.md (modified)
@@ -1,4 +1,43 @@
-## SAM2
+# LLAMA

## Eval on Llama 3.1 8B and Llama 3.2 3B

We use lm-eval tasks to evaluate TorchAO quantization APIs on Hugging Face models. The results are in the tables below:

> Review comment: would specify what task this is for.

| Model Name   | Quantization Technique | Acc (hellaswag) | Acc Norm (hellaswag) | Word Perplexity (wikitext) | Model Size (GB) |
|--------------|------------------------|-----------------|----------------------|----------------------------|-----------------|
| Llama 3.1 8B | None                   | 60.01           | 78.84                | 7.33                       | 15.01           |
| Llama 3.1 8B | int4wo-128             | 58.10           | 77.06                | 8.25                       | 4.76            |
| Llama 3.1 8B | int8wo                 | 59.92           | 78.95                | 7.34                       | 8.04            |
| Llama 3.1 8B | int8dq                 | 60.01           | 78.82                | 7.45                       | 8.03            |
| Llama 3.1 8B | float8wo               | 59.83           | 78.61                | 7.37                       | 8.03            |
| Llama 3.1 8B | float8dq (PerRow)      | 59.86           | 78.57                | 7.41                       | 8.04            |
| Llama 3.1 8B | float8dq (PerTensor)   | 59.95           | 78.66                | 7.42                       | 8.03            |
| Llama 3.1 8B | gemlite (gp=128)       | 58.48           | 77.34                | 8.07                       | 4.76            |

> Review comment: per tensor is more accurate than per row quantization?

| Model Name   | Quantization Technique | Acc (hellaswag) | Acc Norm (hellaswag) | Word Perplexity (wikitext) | Model Size (GB) |
|--------------|------------------------|-----------------|----------------------|----------------------------|-----------------|
| Llama 3.2 3B | None                   | 55.27           | 73.70                | 9.26                       | 6.43            |
| Llama 3.2 3B | int4wo-128             | 53.13           | 71.31                | 10.36                      | 2.29            |
| Llama 3.2 3B | int8wo                 | 55.15           | 73.44                | 9.28                       | 3.61            |
| Llama 3.2 3B | int8dq                 | 55.00           | 73.29                | 9.43                       | 3.61            |
| Llama 3.2 3B | float8wo               | 55.18           | 73.58                | 9.31                       | 3.61            |
| Llama 3.2 3B | float8dq (PerRow)      | 55.18           | 73.37                | 9.33                       | 3.61            |
| Llama 3.2 3B | float8dq (PerTensor)   | 55.16           | 73.53                | 9.35                       | 3.61            |
| Llama 3.2 3B | gemlite (gp=128)       | 53.71           | 71.99                | 10.05                      | 2.29            |

To generate the above results, run:
```
sh benchmarks/_models/eval_hf_models.sh
```

To run lm-eval for a different Hugging Face model with a TorchAO quantization technique, run:
```
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row --tasks wikitext hellaswag
```
Replace the model id, quantization technique, and tasks with your desired values. Please refer to the [HuggingFace <-> TorchAO](https://huggingface.co/docs/transformers/main/en//quantization/torchao) integration docs for more details about the supported quantization techniques.
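
For reference, the same float8dq-row setup can also be applied directly from Python via the Transformers TorchAO integration. A minimal sketch, assuming a recent torchao that exports `Float8DynamicActivationFloat8WeightConfig` and `PerRow` (per the integration docs linked above); the output path is illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

# float8 dynamic activation + float8 weight quantization, per-row granularity
quant_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=TorchAoConfig(quant_type=quant_config),
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")

# Save the quantized checkpoint so lm_eval can load it via
# --model_args pretrained=<dir> (illustrative path)
model.save_pretrained("quantized_models/Llama-3.1-8B-float8dq-row", safe_serialization=False)
tokenizer.save_pretrained("quantized_models/Llama-3.1-8B-float8dq-row")
```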

# SAM2
sam2 is a fork of https://github.com/facebookresearch/sam2 at commit c2ec8e14a185632b0a5d8b161928ceb50197eddc

It includes
> Review comment: gemlite can be 4 bit or 8 bit, should probably specify that this is for 4 bit.
> https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py#L968
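
On the 4-bit vs 8-bit point: a hypothetical sketch of how the two gemlite variants could be spelled out, assuming torchao exposes `GemliteUIntXWeightOnlyConfig` with `group_size` and `bit_width` parameters (names inferred from the quant_api.py link above, not verified here):

```python
# Hypothetical sketch: gemlite weight-only quantization at 4 vs 8 bits.
# GemliteUIntXWeightOnlyConfig and its parameter names are assumptions
# based on the quant_api.py link above; check your torchao version.
from torchao.quantization import GemliteUIntXWeightOnlyConfig

gemlite_4bit = GemliteUIntXWeightOnlyConfig(group_size=128, bit_width=4)
gemlite_8bit = GemliteUIntXWeightOnlyConfig(group_size=128, bit_width=8)
```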