
Feature Branch for AWQ Modifier #181


Closed
wants to merge 13 commits
3 changes: 3 additions & 0 deletions src/llmcompressor/modifiers/awq/__init__.py
@@ -0,0 +1,3 @@
# flake8: noqa

from .base import *
720 changes: 720 additions & 0 deletions src/llmcompressor/modifiers/awq/base.py

Large diffs are not rendered by default.

113 changes: 113 additions & 0 deletions src/llmcompressor/pytorch/utils/helpers.py
@@ -2,6 +2,9 @@
Utility / helper functions
"""

import functools
import gc
import inspect
import os
import random
import re
@@ -85,6 +88,11 @@
    "detach",
    "adjust_quantization_for_onnx_export",
    "get_dependency_order",
    "pseudo_quantize_tensor",
    "pseudo_dequantize_linear",
    "tensor_forward_with_input_args",
    "sanitize_kwargs_for_module",
    "clear_memory",
]


@@ -680,6 +688,45 @@ def mask_difference(old_mask: Tensor, new_mask: Tensor) -> Tensor:
    return -1.0 * newly_masked + newly_unmasked


def sanitize_kwargs_for_module(
    kwargs: Dict[str, Any], module: Module
) -> Dict[str, Any]:
    """
    Sanitize the kwargs for a Module by removing any keys that are not
    in the signature of the forward method.

    :param kwargs: the kwargs to sanitize
    :param module: the Module to sanitize the kwargs for
    :return: the sanitized kwargs for the callable object
    """
    if not isinstance(kwargs, dict):
        raise TypeError(f"Expected a dictionary as kwargs, but got {kwargs}")

    allowed_params = inspect.signature(module.forward).parameters
    return {key: value for key, value in kwargs.items() if key in allowed_params}


def tensor_forward_with_input_args(
    module: Module, inputs: Tensor, input_kwargs: Dict[str, Any]
) -> Tensor:
    """
    Forward the given inputs through the given module with the given input_kwargs.
    This function is a wrapper around tensors_module_forward that ensures that the
    input_kwargs are sanitized and passed to the module as keyword arguments during
    the forward pass.

    :param module: the module to forward the inputs through
    :param inputs: the inputs to forward through the module
    :param input_kwargs: the keyword arguments to pass to the
        module during the forward pass
    :return: the output of the module after forwarding the inputs through it
    """
    inputs = inputs.to(next(module.parameters()).device)
    input_kwargs = sanitize_kwargs_for_module(input_kwargs, module)

    return tensors_module_forward(inputs, functools.partial(module, **input_kwargs))
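
For context, a minimal sketch (not part of this diff) of how these two helpers combine; the import path is taken from the file location above, and a plain torch.nn.Linear stands in for a real model layer:

import torch
from torch.nn import Linear

from llmcompressor.pytorch.utils.helpers import (
    sanitize_kwargs_for_module,
    tensor_forward_with_input_args,
)

layer = Linear(10, 20)

# Keys that Linear.forward does not accept are dropped ...
kwargs = {"input": torch.randn(1, 10), "attention_mask": torch.ones(1, 10)}
print(sanitize_kwargs_for_module(kwargs, layer).keys())  # dict_keys(['input'])

# ... so unexpected kwargs no longer break the forward pass.
out = tensor_forward_with_input_args(layer, torch.randn(1, 10), {"not_used": 123})
print(out.shape)  # torch.Size([1, 20])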


##############################
#
# pytorch module helper functions
@@ -1194,3 +1241,69 @@ def swap_modules(
    parent.__setattr__(sections[-1], submodule_to_replace)

    return cur


Collaborator: can we please add a comment as to what this function's purpose/function is?

def pseudo_quantize_tensor(
    w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
):
    # Fake-quantizes `w` to `bit_width` bits (per group when group_size > 0) and
    # returns the dequantized tensor together with the scales and zero points.
    org_w_shape = w.shape
    if group_size > 0:
        assert org_w_shape[-1] % group_size == 0
        w = w.reshape(-1, group_size)
    assert w.dim() == 2
    assert torch.isnan(w).sum() == 0

    if not symmetric:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2**bit_width - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
        w = (
            torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
        ) * scales
        zeros = zeros.view(org_w_shape[0], -1)
    else:
        max_val = w.abs().amax(dim=1, keepdim=True)
        max_val = max_val.clamp(min=1e-5)
        max_int = 2 ** (bit_width - 1) - 1
        min_int = -(2 ** (bit_width - 1))
        scales = max_val / max_int
        zeros = None
        w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    scales = scales.view(org_w_shape[0], -1)
    w = w.reshape(org_w_shape)

    return w, scales, zeros
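
As an illustration (not part of the diff), a rough sketch of what pseudo_quantize_tensor returns for a 4-bit, group-size-128 configuration, assuming the helper above is in scope; the tensor sizes are made up and the shape comments follow the reshaping logic above:

import torch

w = torch.randn(512, 1024)
w_q, scales, zeros = pseudo_quantize_tensor(
    w, symmetric=False, bit_width=4, group_size=128
)

print(w_q.shape)     # torch.Size([512, 1024]) -- same shape, values snapped to the 4-bit grid
print(scales.shape)  # torch.Size([512, 8])    -- one scale per group of 128 columns
print(zeros.shape)   # torch.Size([512, 8])    -- zero points only in the asymmetric case
print((w - w_q).abs().max())  # worst-case rounding error, roughly half a scale step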


def pseudo_dequantize_linear(
    w: Module,
    scales: torch.Tensor,
    zeros: Optional[torch.Tensor] = None,
    symmetric: bool = False,
):
    # `w` is the Linear layer holding quantized integer weights; expand the
    # per-group scales (and zero points) to the full weight shape, then dequantize.
    repeat_count = w.weight.data.shape[-1] // scales.shape[-1]
    scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape)

    # dequantize
    if not symmetric:
        zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape)
        w = (w.weight.data - zeros) * scales
    else:
        w = w.weight.data * scales

    return w
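
And a comparable sketch for the dequantization helper (again not part of the diff, and assuming the function above is in scope): the layer's weight is assumed to already hold integer codes produced elsewhere, with per-group scales and zero points of matching shape:

import torch
from torch.nn import Linear

layer = Linear(32, 16, bias=False)
layer.weight.data = torch.randint(0, 16, (16, 32)).float()  # pretend 4-bit codes
scales = torch.rand(16, 4) * 0.1   # 4 groups of 8 columns per output row
zeros = torch.full((16, 4), 8.0)

w_fp = pseudo_dequantize_linear(layer, scales, zeros, symmetric=False)
print(w_fp.shape)  # torch.Size([16, 32]) -- floating-point weights recovered from the codes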


Collaborator: this is super not how the python garbage collector works

def clear_memory(value: Optional[Any] = None):
    # NOTE: `del value` only drops this function's local reference; callers must
    # release their own references before gc.collect() can reclaim the object.
    if value is not None:
        del value
    gc.collect()
    torch.cuda.empty_cache()
1 change: 1 addition & 0 deletions src/llmcompressor/transformers/finetune/data/__init__.py
@@ -8,6 +8,7 @@
from .evolcodealpaca import EvolCodeAlpacaDataset
from .gsm8k import GSM8KDataset
from .open_platypus import OpenPlatypusDataset
from .pile import PileEvalDataset
from .ptb import PtbDataset
from .ultrachat_200k import UltraChatDataset
from .wikitext import WikiTextDataset
47 changes: 47 additions & 0 deletions src/llmcompressor/transformers/finetune/data/pile.py
@@ -0,0 +1,47 @@
from copy import deepcopy
from typing import Optional

from llmcompressor.transformers.finetune.data import TextGenerationDataset


@TextGenerationDataset.register(name="pile_eval")
class PileEvalDataset(TextGenerationDataset):
"""
Child text generation class for the PileEval dataset

:param data_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param tokenizer: tokenizer to use on dataset
"""

def __init__(self, data_args, split, tokenizer):
data_args = deepcopy(data_args)
data_args.dataset = "mit-han-lab/pile-val-backup"
super().__init__(
text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
)

def get_raw_dataset(self, cache_dir: Optional[str] = None):
"""
Load the raw dataset from Hugging Face, using cached copy if available.
Additionally reformats the entries to fit the template.

:param cache_dir: disk location to search for cached dataset
:return: the requested dataset
"""
raw_dataset = super().get_raw_dataset(cache_dir=cache_dir)

def restructure_fn(sample):
sample["text"] = sample["text"].strip()
return sample

raw_dataset = self.map(
raw_dataset,
function=restructure_fn,
batched=False,
remove_columns=["meta"],
num_proc=self.data_args.preprocessing_num_workers,
load_from_cache_file=not self.data_args.overwrite_cache,
desc="Restructuring Pile Dataset",
)
return raw_dataset
21 changes: 21 additions & 0 deletions src/llmcompressor/utils/pytorch/module.py
@@ -60,6 +60,7 @@
    "get_layers_params",
    "get_matching_layer",
    "get_no_split_params",
    "get_parent_by_name",
]


@@ -338,3 +339,23 @@ def get_no_split_params(module: Module) -> Union[str, List[str]]:
    if hasattr(model, "_no_split_modules"):
        return model._no_split_modules
    return ALL_TARGET


def get_parent_by_name(layer_name: str, model: Module) -> Tuple[str, Module]:
    """
    Get the parent layer of a layer by name.

    :param layer_name: Name of the layer to find the parent of.
    :param model: Model to search for the parent layer.
    :return: Tuple containing the name of the parent layer
        and the parent layer itself.
    """
    if not any(layer_name == name for name, _ in model.named_modules()):
        raise ValueError(f"Layer '{layer_name}' not found in model")

    parent_name_parts = layer_name.split(".")[:-1]
    if not parent_name_parts:
        return "", model

    parent_name = ".".join(parent_name_parts)
    return get_layer(parent_name, model)
28 changes: 28 additions & 0 deletions tests/llmcompressor/modifiers/awq/test_base.py
@@ -0,0 +1,28 @@
import unittest

import pytest

from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.factory import ModifierFactory
from tests.llmcompressor.modifiers.conf import setup_modifier_factory


@pytest.mark.unit
class TestAWQIsRegistered(unittest.TestCase):
    def setUp(self):
        self.kwargs = {}
        setup_modifier_factory()

    def test_awq_is_registered(self):
        modifier = ModifierFactory.create(
            type_="AWQModifier",
            allow_experimental=False,
            allow_registered=True,
            **self.kwargs,
        )

        self.assertIsInstance(
            modifier,
            AWQModifier,
            "PyTorch AWQModifier not registered",
        )
42 changes: 42 additions & 0 deletions tests/llmcompressor/pytorch/utils/test_helpers.py
@@ -16,9 +16,11 @@
    get_optim_learning_rate,
    mask_difference,
    memory_aware_threshold,
    sanitize_kwargs_for_module,
    set_optim_learning_rate,
    tensor_density,
    tensor_export,
    tensor_forward_with_input_args,
    tensor_sample,
    tensor_sparsity,
    tensors_batch_size,
@@ -855,3 +857,43 @@ def test_memory_aware_threshold(tensor, idx):

    if prior_state is not None:
        os.environ[MEMORY_BOUNDED] = prior_state


class TestSanitizeKwargsForModule:
    @pytest.fixture
    def module(self):
        return Linear(10, 20)

    def test_sanitize_kwargs_for_module_not_dict(self, module):
        # Test with kwargs that are not a dictionary
        with pytest.raises(TypeError):
            sanitize_kwargs_for_module("not a dictionary", module)

    def test_sanitize_kwargs_for_module_not_in_signature(self, module):
        # Test with kwargs that are not in the signature of the forward method
        kwargs = {"not_in_signature": 123}
        sanitized_kwargs = sanitize_kwargs_for_module(kwargs, module)
        assert sanitized_kwargs == {}

    def test_sanitize_kwargs_for_module_in_signature(self, module):
        # Test with kwargs that are in the signature of the forward method
        kwargs = {"input": torch.randn(1, 10)}
        sanitized_kwargs = sanitize_kwargs_for_module(kwargs, module)
        assert sanitized_kwargs == kwargs


class TestTensorForwardWithInputArgs:
    @pytest.fixture
    def module(self):
        return Linear(10, 20)

    def test_tensor_forward_with_input_args(self, module):
        # Test with valid inputs and input_kwargs
        inputs = torch.randn(1, 10)
        input_kwargs = {}
        output = tensor_forward_with_input_args(module, inputs, input_kwargs)
        assert output.shape == (1, 20)

        # Test with input_kwargs that are not in the signature of the forward method
        input_kwargs = {"not_in_signature": 123}
        tensor_forward_with_input_args(module, inputs, input_kwargs)
17 changes: 17 additions & 0 deletions tests/llmcompressor/transformers/finetune/data/test_registry.py
@@ -3,6 +3,7 @@
from llmcompressor.transformers.finetune.data import (
    C4Dataset,
    OpenPlatypusDataset,
    PileEvalDataset,
    TextGenerationDataset,
    WikiTextDataset,
)
@@ -57,3 +58,19 @@ def test_open_platypus_initializes(tiny_llama_tokenizer):
    assert op_manager.text_column == "text"
    assert not op_manager.padding
    assert op_manager.max_seq_length == data_args.max_seq_length


@pytest.mark.usefixtures("tiny_llama_tokenizer")
def test_pile_eval_initializes(tiny_llama_tokenizer):
    data_args = DataTrainingArguments(dataset="pile_eval", pad_to_max_length=False)
    pile_eval_manager = TextGenerationDataset.load_from_registry(
        data_args.dataset,
        data_args=data_args,
        split=None,
        tokenizer=tiny_llama_tokenizer,
    )
    assert isinstance(pile_eval_manager, TextGenerationDataset)
    assert isinstance(pile_eval_manager, PileEvalDataset)
    assert pile_eval_manager.text_column == "text"
    assert not pile_eval_manager.padding
    assert pile_eval_manager.max_seq_length == data_args.max_seq_length
31 changes: 31 additions & 0 deletions tests/llmcompressor/utils/pytorch/test_module.py
@@ -0,0 +1,31 @@
import unittest

import torch.nn as nn

from llmcompressor.utils.pytorch import get_parent_by_name


class TestGetParentByName(unittest.TestCase):
    def setUp(self):
        self.model = nn.Sequential(
            nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10), nn.Softmax(dim=1)
        )

    def test_get_parent_by_name(self):
        # Test getting the parent of a non-existent layer
        with self.assertRaises(ValueError):
            get_parent_by_name("non_existent_layer", self.model)

        # Test getting the parent of the first layer
        name, parent = get_parent_by_name("0", self.model)
        self.assertEqual(parent, self.model)

        # Test getting the parent of a nested layer
        nested_model = nn.Sequential(
            nn.Linear(10, 20),
            nn.Sequential(nn.ReLU(), nn.Linear(20, 10)),
            nn.Softmax(dim=1),
        )
        name, parent = get_parent_by_name("1.1", nested_model)
        self.assertEqual(parent, nested_model[1])
        self.assertEqual(name, "1")