-
Notifications
You must be signed in to change notification settings - Fork 127
Feature Branch for AWQ Modifier #181
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
f69a2e7
Empty commit for feature branch
rahul-tuli 62bdb6b
Add pile eval dataset (#179)
rahul-tuli f5db0a1
Add PileEvalDataset to data/__init__.py and create pile.py dataset mo…
rahul-tuli 7a112d6
Some cleanup
rahul-tuli ff92f58
Add get_parent_by_name to src/llmcompressor/utils/pytorch/module.py
rahul-tuli 6e7d221
Add tensor_forward_with_input_args and sanitize_kwargs_for_module fun…
rahul-tuli efbad39
Add PileEvalDataset to data/__init__.py and create pile.py dataset mo…
rahul-tuli b55a58d
Some cleanup
rahul-tuli ce1b3ec
Cleanup
rahul-tuli 6f7e316
fix condition when self.end is 0
rahul-tuli 488b6c9
Add: Weight clipping to AWQModifier (#184)
rahul-tuli 2037c89
Update src/llmcompressor/pytorch/utils/helpers.py
brian-dellabetta 942d2c7
formatting
brian-dellabetta File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# flake8: noqa | ||
|
||
from .base import * |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,9 @@ | |
Utility / helper functions | ||
""" | ||
|
||
import functools | ||
import gc | ||
import inspect | ||
import os | ||
import random | ||
import re | ||
|
@@ -85,6 +88,11 @@ | |
"detach", | ||
"adjust_quantization_for_onnx_export", | ||
"get_dependency_order", | ||
"pseudo_quantize_tensor", | ||
"pseudo_dequantize_linear", | ||
"tensor_forward_with_input_args", | ||
"sanitize_kwargs_for_module", | ||
"clear_memory", | ||
] | ||
|
||
|
||
|
@@ -680,6 +688,45 @@ def mask_difference(old_mask: Tensor, new_mask: Tensor) -> Tensor: | |
return -1.0 * newly_masked + newly_unmasked | ||
|
||
|
||
def sanitize_kwargs_for_module(
    kwargs: Dict[str, Any], module: Module
) -> Dict[str, Any]:
    """
    Filter a kwargs dict down to the keys accepted by a Module's forward method.

    :param kwargs: the kwargs to sanitize
    :param module: the Module whose ``forward`` signature defines the allowed keys
    :return: a new dict containing only the entries whose key is a parameter
        of ``module.forward``
    :raises TypeError: if ``kwargs`` is not a dict
    """
    if not isinstance(kwargs, dict):
        raise TypeError(f"Expected a dictionary as kwargs, but got {kwargs}")

    accepted_params = inspect.signature(module.forward).parameters

    sanitized = {}
    for name, value in kwargs.items():
        if name in accepted_params:
            sanitized[name] = value
    return sanitized
|
||
|
||
def tensor_forward_with_input_args(
    module: Module, inputs: Tensor, input_kwargs: Dict[str, Any]
) -> Tensor:
    """
    Forward ``inputs`` through ``module``, passing only the keyword arguments
    the module's forward method accepts.

    Wrapper around ``tensors_module_forward`` that moves the inputs onto the
    module's device and sanitizes ``input_kwargs`` before the forward pass.

    :param module: the module to forward the inputs through
    :param inputs: the inputs to forward through the module
    :param input_kwargs: keyword arguments to pass to the module during the
        forward pass (unsupported keys are dropped)
    :return: the output of the module after forwarding the inputs through it
    """
    # place inputs on the same device as the module's parameters
    target_device = next(module.parameters()).device
    filtered_kwargs = sanitize_kwargs_for_module(input_kwargs, module)

    forward_fn = functools.partial(module, **filtered_kwargs)
    return tensors_module_forward(inputs.to(target_device), forward_fn)
|
||
|
||
############################## | ||
# | ||
# pytorch module helper functions | ||
|
@@ -1194,3 +1241,69 @@ def swap_modules( | |
parent.__setattr__(sections[-1], submodule_to_replace) | ||
|
||
return cur | ||
|
||
|
||
def pseudo_quantize_tensor(
    w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
):
    """
    Pseudo-quantize (quantize then immediately dequantize) a 2D weight tensor.

    Used to simulate the rounding error of integer quantization while keeping
    the tensor in floating point.

    :param w: the weight tensor to pseudo-quantize; must be 2D, or reshapeable
        into groups of ``group_size`` along the last dimension
    :param symmetric: if True use symmetric quantization (no zero point),
        otherwise asymmetric (min/max based, with zero points)
    :param bit_width: the number of bits to quantize to
    :param group_size: number of elements sharing one scale; -1 means one
        scale per row
    :return: tuple of (pseudo-quantized tensor with original shape,
        scales of shape ``(rows, num_groups)``, zero points or None when
        ``symmetric`` is True)
    :raises ValueError: if the last dim is not divisible by ``group_size``,
        the (grouped) tensor is not 2D, or NaNs are encountered
    """
    org_w_shape = w.shape
    if group_size > 0:
        # validation raises rather than asserts so it survives `python -O`
        if org_w_shape[-1] % group_size != 0:
            raise ValueError(
                f"last dimension {org_w_shape[-1]} is not divisible by "
                f"group_size {group_size}"
            )
        w = w.reshape(-1, group_size)
    if w.dim() != 2:
        raise ValueError(f"expected a 2D tensor, got {w.dim()}D")
    if torch.isnan(w).any():
        raise ValueError("input tensor contains NaN values")

    if not symmetric:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2**bit_width - 1
        min_int = 0
        # clamp avoids division by zero for constant (max == min) groups
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
        w = (
            torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
        ) * scales
        zeros = zeros.view(org_w_shape[0], -1)
    else:
        max_val = w.abs().amax(dim=1, keepdim=True)
        max_val = max_val.clamp(min=1e-5)
        max_int = 2 ** (bit_width - 1) - 1
        min_int = -(2 ** (bit_width - 1))
        scales = max_val / max_int
        zeros = None
        w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales

    if torch.isnan(scales).any() or torch.isnan(w).any():
        raise ValueError("quantization produced NaN values")

    scales = scales.view(org_w_shape[0], -1)
    w = w.reshape(org_w_shape)

    return w, scales, zeros
|
||
|
||
def pseudo_dequantize_linear(
    w: Module,
    scales: torch.Tensor,
    zeros: Optional[torch.Tensor] = None,
    symmetric: bool = False,
):
    """
    Reconstruct the dequantized weights of a quantized linear layer.

    NOTE: ``w`` is a linear Module whose ``weight`` holds the quantized
    values, not a bare tensor (the previous ``torch.Tensor`` annotation was
    wrong — the body reads ``w.weight.data``).

    :param w: the linear layer whose weights to dequantize
    :param scales: per-group scales of shape ``(out_features, num_groups)``
    :param zeros: per-group zero points; required when ``symmetric`` is False
    :param symmetric: if True dequantize without zero points
    :return: the dequantized weight tensor, same shape as ``w.weight``
    :raises ValueError: if ``zeros`` is missing for asymmetric dequantization
    """
    if not symmetric and zeros is None:
        raise ValueError(
            "zeros must be provided for asymmetric (symmetric=False) dequantization"
        )

    # expand the per-group scales to cover every weight element
    repeat_count = w.weight.data.shape[-1] // scales.shape[-1]
    scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape)

    # dequantize
    if not symmetric:
        zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape)
        return (w.weight.data - zeros) * scales
    return w.weight.data * scales
|
||
|
||
def clear_memory(value: Optional[Any] = None):
    """
    Best-effort release of Python and CUDA cached memory.

    NOTE: deleting ``value`` here only removes this function's local
    reference — the caller's reference (and any others) keep the object
    alive, so this does not force collection of ``value`` by itself.
    Callers that want an object collected must drop their own references
    before calling this.

    :param value: optional object whose local reference is dropped before
        garbage collection runs
    """
    if value is not None:
        del value
    # collect unreachable Python objects, then release CUDA's cached blocks
    gc.collect()
    torch.cuda.empty_cache()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from copy import deepcopy | ||
from typing import Optional | ||
|
||
from llmcompressor.transformers.finetune.data import TextGenerationDataset | ||
|
||
|
||
@TextGenerationDataset.register(name="pile_eval")
class PileEvalDataset(TextGenerationDataset):
    """
    Child text generation class for the PileEval dataset

    :param data_args: configuration settings for dataset loading
    :param split: split from dataset to load, for instance `test` or `train[:5%]`
    :param tokenizer: tokenizer to use on dataset
    """

    def __init__(self, data_args, split, tokenizer):
        # copy so the hard-coded dataset name does not leak into caller state
        pile_args = deepcopy(data_args)
        pile_args.dataset = "mit-han-lab/pile-val-backup"
        super().__init__(
            text_column="text", data_args=pile_args, split=split, tokenizer=tokenizer
        )

    def get_raw_dataset(self, cache_dir: Optional[str] = None):
        """
        Load the raw dataset from Hugging Face, using cached copy if available.
        Additionally reformats the entries to fit the template.

        :param cache_dir: disk location to search for cached dataset
        :return: the requested dataset
        """

        def restructure_fn(sample):
            # strip surrounding whitespace from each text entry
            sample["text"] = sample["text"].strip()
            return sample

        dataset = super().get_raw_dataset(cache_dir=cache_dir)
        return self.map(
            dataset,
            function=restructure_fn,
            batched=False,
            remove_columns=["meta"],
            num_proc=self.data_args.preprocessing_num_workers,
            load_from_cache_file=not self.data_args.overwrite_cache,
            desc="Restructuring Pile Dataset",
        )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import unittest | ||
|
||
import pytest | ||
|
||
from llmcompressor.modifiers.awq import AWQModifier | ||
from llmcompressor.modifiers.factory import ModifierFactory | ||
from tests.llmcompressor.modifiers.conf import setup_modifier_factory | ||
|
||
|
||
@pytest.mark.unit
class TestAWQIsRegistered(unittest.TestCase):
    """Verify that AWQModifier is discoverable through the ModifierFactory."""

    def setUp(self):
        self.kwargs = {}
        setup_modifier_factory()

    def test_awq_is_registered(self):
        # the factory should resolve the registered (non-experimental) modifier
        created_modifier = ModifierFactory.create(
            type_="AWQModifier",
            allow_experimental=False,
            allow_registered=True,
            **self.kwargs,
        )

        self.assertIsInstance(
            created_modifier,
            AWQModifier,
            "PyTorch AWQModifier not registered",
        )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import unittest | ||
|
||
import torch.nn as nn | ||
|
||
from llmcompressor.utils.pytorch import get_parent_by_name | ||
|
||
|
||
class TestGetParentByName(unittest.TestCase):
    """Tests for get_parent_by_name on flat and nested Sequential models."""

    def setUp(self):
        self.model = nn.Sequential(
            nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10), nn.Softmax(dim=1)
        )

    def test_get_parent_by_name(self):
        # a name not present anywhere in the model raises
        with self.assertRaises(ValueError):
            get_parent_by_name("non_existent_layer", self.model)

        # the parent of a top-level layer is the model itself
        name, parent = get_parent_by_name("0", self.model)
        self.assertEqual(parent, self.model)

        # for a nested layer, the innermost containing module is returned
        nested = nn.Sequential(
            nn.Linear(10, 20),
            nn.Sequential(nn.ReLU(), nn.Linear(20, 10)),
            nn.Softmax(dim=1),
        )
        name, parent = get_parent_by_name("1.1", nested)
        self.assertEqual(parent, nested[1])
        self.assertEqual(name, "1")
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we please add a comment as to what this function's purpose/function is?