mixed-precision quantization milestone1: naive_intNwo + eval/benchmark framework #531
Merged

12 commits:
af83deb  milestone1: naive_intNwo + eval/benchmark  (Hanxian97)
02ef81b  remove experiment scripts  (Hanxian97)
cf2c134  remove exp files  (Hanxian97)
1055f14  use default ZeroPointDomain.INT for int2/3/5/6  (Hanxian97)
c00b16d  renamed test_naive_intNwo.py to test_mixed_precision.py  (Hanxian97)
f765eef  updated intNwo with _get_linear_subclass_inserter  (Hanxian97)
9a343a4  adjust sqnr threshold according to bit width  (Hanxian97)
aafe38e  fixed test for int4wo and add __init__.py  (Hanxian97)
1bfa370  skip test_aq_int8_weight_only_quant_3_subclass due to seg fault on ni…  (Hanxian97)
f4fccf3  edit the sqnr threshold  (Hanxian97)
8e787b6  add unittest  (Hanxian97)
e516f0b  correct import path  (Hanxian97)
test_mixed_precision.py (new file, 32 additions):
import unittest

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only, int4_weight_only
from torchao.quantization.utils import compute_error
from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only

_CUDA_IS_AVAILABLE = torch.cuda.is_available()

class TestWeightOnlyQuantNaive(unittest.TestCase):

    def test_quantization_intNwo(self):
        # skip int4wo for now since it is under development in torchao
        for quantization_bit in [2, 3, 5, 6, 8]:
            for symmetric in [False, True]:
                with self.subTest(quantization_bit=quantization_bit, symmetric=symmetric):
                    for x_shape in [[64, 32], [80, 80, 80, 32], [16, 64, 32]]:
                        x = torch.randn(*x_shape, dtype=torch.bfloat16)
                        m = nn.Sequential(nn.Linear(32, 80)).bfloat16()
                        y_ref = m(x)
                        quantize_(m, intN_weight_only(n=quantization_bit, group_size=32, symmetric=symmetric))
                        y_wo = m(x)
                        sqnr = compute_error(y_ref, y_wo)
                        # SQNR_dB can be approximated by 6.02n, where n is the bit width of the quantization
                        # e.g., we set the SQNR threshold to 44 dB for 8-bit, which 6.02 * 8 = 48.16 dB comfortably exceeds
                        expected_sqnr_threshold = 44.0 - (8 - quantization_bit) * 6.02
                        self.assertGreater(sqnr, expected_sqnr_threshold, f"sqnr: {sqnr} is too low")


if __name__ == '__main__':
    unittest.main()
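For reference, the per-bit-width thresholds this formula produces can be tabulated with a quick standalone sketch (plain Python, using nothing beyond the formula from the test above):

# thresholds implied by expected_sqnr_threshold = 44.0 - (8 - n) * 6.02,
# i.e. the ~6.02 dB-per-bit SQNR rule anchored at 44 dB for 8-bit weights
for n in [2, 3, 5, 6, 8]:
    print(f"{n}-bit: SQNR must exceed {44.0 - (8 - n) * 6.02:.2f} dB")
# 2-bit: 7.88 dB, 3-bit: 13.90 dB, 5-bit: 25.94 dB, 6-bit: 31.96 dB, 8-bit: 44.00 dB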
Empty file.
torchao/quantization/prototype/mixed_precision/scripts/__init__.py (1 addition):
from .naive_intNwo import intN_weight_only
torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py (95 additions):
import torch
import torch.nn as nn

from naive_intNwo import intN_weight_only
from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_eval.models.huggingface import HFLM
from lm_eval.evaluator import evaluate
from lm_eval.tasks import get_task_dict

from torchao.quantization import quantize_, int8_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight
from torchao._models._eval import TransformerEvalWrapper

from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
)

from torchao.quantization.quant_api import autoquant


torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True


def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length, sensi_bit, non_sensi_bit, quant_sym, group_size):

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)

    if quantization == "autoquant":
        model = autoquant(model.to(device=device))

    # naive implementation of uniform-precision quantization for all layers
    elif quantization in ["2", "3", "4", "5", "6", "8"]:
        quantize_(model.to(device=device), intN_weight_only(n=int(quantization), group_size=group_size, symmetric=quant_sym))

    # mixed-precision quantization for Llama3
    elif quantization == "MP_llama3":

        # filter for sensitive layers (the first 3 and last 2 layers for Llama3)
        def filter_fn_sen(child: torch.nn.Module, cur_fqn: str) -> bool:
            return isinstance(child, nn.Linear) and any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])

        # filter for non-sensitive layers (the other 27 layers for Llama3)
        def filter_fn_nonsen(child: torch.nn.Module, cur_fqn: str) -> bool:
            return isinstance(child, nn.Linear) and not any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])

        # quantize the sensitive layers
        if sensi_bit != 16:
            quantize_(model.to(device=device), intN_weight_only(n=sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_sen)

        # quantize the less-sensitive layers; skip the extra .to(device) when the
        # sensitive layers were already int4-quantized above
        if sensi_bit == 4:
            quantize_(model, intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)
        else:
            quantize_(model.to(device=device), intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)

    if compile:
        model = torch.compile(model, mode="max-autotune", fullgraph=True)

    with torch.no_grad():
        result = evaluate(
            HFLM(
                pretrained=model,
                tokenizer=tokenizer,
                batch_size=batch_size,
                max_length=max_length),
            get_task_dict(tasks),
            limit=limit,
        )

    for task, res in result["results"].items():
        print(f"{task}: {res}")


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Run HF Model Evaluation')
    parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/Meta-Llama-3-8B", help='Repository ID to download from HF.')
    parser.add_argument('--tasks', nargs='+', type=str, default=["wikitext"], help='List of lm-eval (EleutherAI) tasks to evaluate; usage: --tasks task1 task2')
    parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
    parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
    parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
    parser.add_argument('-q', '--quantization', default="None", choices=["2", "3", "4", "5", "6", "8", "MP_llama3", "None"], help='Which quantization technique to apply: choose from ["2", "3", "4", "5", "6", "8"] for uniform quantization, "MP_llama3" for mixed precision on Llama3 (set sensi_bit and non_sensi_bit accordingly), or "None" for no quantization')
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation; note int8wo and int4wo work best with small batch sizes, while int8dq works better with large batch sizes')
    parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
    parser.add_argument('--sensi_bit', type=int, default=16, choices=[16, 8, 6, 5, 4, 3], help='Bit setting for sensitive layers')
    parser.add_argument('--non_sensi_bit', type=int, default=8, choices=[8, 6, 5, 4, 3, 2], help='Bit setting for non-sensitive layers')
    parser.add_argument('--quant_sym', type=bool, default=False, help='Symmetric or asymmetric quantization, asymmetric by default')
    parser.add_argument('--group_size', type=int, default=32, help='Group size to perform quantization on')
    args = parser.parse_args()
    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length, args.sensi_bit, args.non_sensi_bit, args.quant_sym, args.group_size)
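To make the layer split concrete, here is a minimal self-contained sketch of how the FQN-based filters above classify linear layers; the FQN strings are hypothetical stand-ins for Llama3 module names, not values taken from the script:

import torch.nn as nn

SENSITIVE_MARKERS = ['.0.', '.1.', '.2.', '.30.', '.31.']  # first 3 and last 2 decoder layers

def is_sensitive(child: nn.Module, cur_fqn: str) -> bool:
    # same predicate as filter_fn_sen above
    return isinstance(child, nn.Linear) and any(m in cur_fqn for m in SENSITIVE_MARKERS)

# hypothetical FQNs shaped like Llama3's 32 decoder layers
for fqn in ["model.layers.0.self_attn.q_proj",
            "model.layers.15.mlp.gate_proj",
            "model.layers.31.self_attn.o_proj"]:
    print(fqn, "->", "sensitive" if is_sensitive(nn.Linear(8, 8), fqn) else "non-sensitive")
# model.layers.0.self_attn.q_proj -> sensitive
# model.layers.15.mlp.gate_proj -> non-sensitive
# model.layers.31.self_attn.o_proj -> sensitive

A representative invocation of the script (repo id and bit choices illustrative) would be: python mp_quant_eval.py --repo_id meta-llama/Meta-Llama-3-8B -q MP_llama3 --sensi_bit 8 --non_sensi_bit 4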
torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py (60 additions):
import torch

from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
)

from torchao.quantization import int8_weight_only, int4_weight_only
from torchao.quantization.quant_api import _get_linear_subclass_inserter

def intN_weight_only(group_size=32, n=8, symmetric=False):
    '''
    Apply int N-bit weight-only quantization to a linear layer.
    Args:
        `group_size`: controls the granularity of quantization; a smaller size is more fine-grained, choices are [512, 256, 128, 64, 32]
        `n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2]
        `symmetric`: whether to use symmetric quantization (asymmetric by default)
    Usage:
        from torchao.quantization import quantize_
        quantize_(model, intN_weight_only(n=your_bit_choice, group_size=group_size), optional_filter_func_for_desired_layers_to_quantize)
    '''
    # for asymmetric quantization
    def apply_intN_weight_only_quant_asym(weight):
        # avoid circular dependency
        from torchao.dtypes import to_affine_quantized
        mapping_type = MappingType.ASYMMETRIC
        block_size = (1, group_size)
        target_dtype = torch.uint8
        quant_min = 0
        quant_max = 2**n - 1
        eps = 1e-6
        preserve_zero = True
        zero_point_dtype = torch.int64
        zero_point_domain = ZeroPointDomain.INT
        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype)  # , preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)

    # for symmetric quantization
    def apply_intN_weight_only_quant_sym(weight):
        # avoid circular dependency
        from torchao.dtypes import to_affine_quantized
        mapping_type = MappingType.SYMMETRIC
        block_size = (1, group_size)
        target_dtype = torch.int8
        eps = 1e-6
        zero_point_dtype = torch.int64
        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

    assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
    if n == 8:
        return int8_weight_only()
    elif n == 4:
        return int4_weight_only(group_size=group_size)
    elif symmetric:
        return _get_linear_subclass_inserter(apply_intN_weight_only_quant_sym)
    else:
        return _get_linear_subclass_inserter(apply_intN_weight_only_quant_asym)
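As a minimal end-to-end usage sketch of intN_weight_only, mirroring the unit test above (the 3-bit setting and toy layer shape are illustrative):

import torch
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only

# toy model: every nn.Linear inside is quantized in place by quantize_
model = nn.Sequential(nn.Linear(32, 80)).bfloat16()
x = torch.randn(16, 32, dtype=torch.bfloat16)
y_ref = model(x)

# 3-bit asymmetric weight-only quantization with group size 32
quantize_(model, intN_weight_only(n=3, group_size=32, symmetric=False))
y_q = model(x)
print(f"max abs error after 3-bit quantization: {(y_ref - y_q).abs().max():.4f}")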