46 changes: 23 additions & 23 deletions src/llmcompressor/modifiers/awq/base.py
@@ -1,5 +1,4 @@
import inspect
from typing import Dict, List, Optional, Tuple, Union

import torch
from compressed_tensors.quantization import disable_quantization
@@ -94,8 +93,6 @@ class AWQModifier(Modifier, QuantizationMixin):
- on_finalize
- clear resolved mappings and captured activations

:param sequential_targets: list of module names to compress in
the same calibration pass
:param mappings: list of activation layers to smooth, and which layers to
scale so that the activations are smoothed.
Each entry of the mapping list should be a list itself, in which the first
@@ -118,27 +115,30 @@ class AWQModifier(Modifier, QuantizationMixin):
model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)

# User-provided vars (in addition to QuantizationMixin args)
sequential_targets: Union[str, List[str], None] = None
mappings: Optional[List[AWQMapping]] = None
offload_device: Optional[torch.device] = None
mappings: list[AWQMapping] | None = None
offload_device: torch.device | None = None
duo_scaling: bool = True

# Private vars set during validation
_num_bits: Optional[int] = PrivateAttr(default=None)
_symmetric: Optional[bool] = PrivateAttr(default=None)
_group_size: Optional[int] = PrivateAttr(default=None)
_num_bits: int | None = PrivateAttr(default=None)
_symmetric: bool | None = PrivateAttr(default=None)
_group_size: int | None = PrivateAttr(default=None)

# Private vars set during initialization, cleared during finalization
_resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list)
_resolved_mappings: list[ResolvedMapping] = PrivateAttr(default_factory=list)
# Cache list of forward input args for each parent module, one dict for each batch
_parent_args_cache: Dict[Module, IntermediatesCache] = PrivateAttr(
_parent_args_cache: dict[Module, IntermediatesCache] = PrivateAttr(
default_factory=dict
)
# Dict[smooth layer name, (activation means, activation counts)]
_smooth_activation_means: Dict[str, Tuple[torch.FloatTensor, int]] = PrivateAttr(
_smooth_activation_means: dict[str, tuple[torch.FloatTensor, int]] = PrivateAttr(
default_factory=dict
)

# NOTE: in case a user wants to run both AWQ and GPTQ before quantizing,
# this is set to True
_supports_disabling_quantization: bool = PrivateAttr(True)

# NOTE: different name chosen to avoid collision with
# QuantizationMixin.validate_model_after, which must be called first
@model_validator(mode="after")
@@ -389,7 +389,7 @@ def _setup_activation_cache_hooks(self) -> None:

def cache_parent_kwargs_hook(
module: torch.nn.Module,
args: Tuple[torch.Tensor, ...],
args: tuple[torch.Tensor, ...],
kwargs,
):
values = inspect.signature(module.forward).bind(*args, **kwargs)
@@ -398,7 +398,7 @@ def create_cache_smooth_activations_hook_fn(smooth_name):
def create_cache_smooth_activations_hook_fn(smooth_name):
def cache_smooth_activations_hook(
_module: torch.nn.Module,
args: Tuple[torch.Tensor, ...],
args: tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
self._smooth_activation_means[smooth_name] = _accumulate_mean(
@@ -559,13 +559,13 @@ def _smooth(module):
v.batch_intermediates.clear()
self._assert_all_activations_consumed()

def _run_samples(self, module: Module) -> List[torch.Tensor]:
def _run_samples(self, module: Module) -> list[torch.Tensor]:
outputs = [
module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module]
]
return [
# If the output is a tuple, assume the first element is the relevant tensor
output[0] if isinstance(output, Tuple) else output
output[0] if isinstance(output, tuple) else output
for output in outputs
]

@@ -574,8 +574,8 @@ def _compute_best_scale(
x_mean: torch.Tensor,
w_mean: torch.Tensor,
parent_module: torch.nn.Module,
linears2scale: List[torch.nn.Linear],
fp16_outputs: List[torch.Tensor],
linears2scale: list[torch.nn.Linear],
fp16_outputs: list[torch.Tensor],
) -> torch.Tensor:
"""
Compute loss and select best scales
@@ -667,8 +667,8 @@ def _compute_best_scale(
@torch.no_grad()
def _compute_loss(
self,
fp16_outputs: List[torch.Tensor],
int_w_outputs: List[torch.Tensor],
fp16_outputs: list[torch.Tensor],
int_w_outputs: list[torch.Tensor],
device: torch.device,
) -> torch.Tensor:
loss = 0.0
@@ -746,8 +746,8 @@ def _pseudo_quantize_tensor(

def _accumulate_mean(
inp: torch.Tensor,
prev_mean_and_count: Optional[Tuple[torch.FloatTensor, int]],
) -> Tuple[torch.FloatTensor, int]:
prev_mean_and_count: tuple[torch.FloatTensor, int] | None,
) -> tuple[torch.FloatTensor, int]:
sum_added = inp.sum(dim=0)
num_added = inp.size(0)
if prev_mean_and_count is None:
@@ -761,7 +761,7 @@ def _accumulate_mean(
return (prev_sum + sum_added) / new_count, new_count


def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]:
"""
Given a list of names, returns the lowest-scope common parent.

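For context, a minimal sketch of the workflow the `_supports_disabling_quantization` note above enables: AWQ runs only its activation smoothing and weight rescaling, and GPTQ performs the actual quantization afterwards. The `oneshot` entrypoint, model id, and dataset name are illustrative assumptions, not part of this diff.

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import GPTQModifier

recipe = [
    # targets=None resolves to an empty quantization config, so AWQ only
    # smooths activations and rescales weights without quantizing them
    AWQModifier(scheme="W4A16", targets=None),
    # GPTQ then performs the weight quantization
    GPTQModifier(scheme="W4A16", targets=["Linear"], ignore=["lm_head"]),
]

oneshot(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # illustrative model id
    dataset="open_platypus",             # illustrative calibration dataset
    recipe=recipe,
    max_seq_length=512,
    num_calibration_samples=64,
)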
13 changes: 6 additions & 7 deletions src/llmcompressor/modifiers/awq/mappings.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Dict, List, Optional

from loguru import logger
from torch.nn import Module
@@ -143,7 +142,7 @@ class AWQMapping:
# ["re:.*dense$"]
# ),
]
AWQ_MAPPING_REGISTRY: Dict[str, list[AWQMapping]] = {
AWQ_MAPPING_REGISTRY: dict[str, list[AWQMapping]] = {
"BloomForCausalLM": _bloom_mappings,
"CohereForCausalLM": _cohere_mappings,
"Cohere2ForCausalLM": _cohere_mappings,
@@ -186,13 +185,13 @@ class ResolvedMapping:

smooth_name: str
smooth_layer: Module
balance_layers: List[Module]
balance_names: Optional[List[str]] = None
parent: Optional[Module] = None
parent_name: Optional[str] = None
balance_layers: list[Module]
balance_names: list[str]
parent: Module
parent_name: str


def get_layer_mappings_from_architecture(architecture: str) -> List[AWQMapping]:
def get_layer_mappings_from_architecture(architecture: str) -> list[AWQMapping]:
"""
:param architecture: str: The architecture of the model
:return: list: The layer mappings for the given architecture
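For architectures not listed in AWQ_MAPPING_REGISTRY, mappings can be passed to the modifier directly. A minimal sketch, assuming AWQMapping takes a smooth-layer pattern and a list of balance-layer patterns, as the registry entries in this file suggest; the regex patterns are illustrative.

from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.awq.mappings import AWQMapping

custom_mappings = [
    # smooth the input layernorm, rescale the attention input projections
    AWQMapping(
        "re:.*input_layernorm$",
        ["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
    ),
    # smooth the post-attention norm, rescale the MLP projections
    AWQMapping(
        "re:.*post_attention_layernorm$",
        ["re:.*gate_proj$", "re:.*up_proj$"],
    ),
]

modifier = AWQModifier(scheme="W4A16", mappings=custom_mappings)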
38 changes: 25 additions & 13 deletions src/llmcompressor/modifiers/quantization/quantization/mixin.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any

import torch
from compressed_tensors.modeling import (
@@ -89,29 +89,35 @@ class QuantizationMixin(HooksMixin):
and kv_cache_scheme != None, the quantization of kv cache will fail
"""

config_groups: Optional[Dict[str, QuantizationScheme]] = None
config_groups: dict[str, QuantizationScheme] | None = None
# NOTE: targets is not the sole source of truth for finding all matching target
# layers in a model. Additional information can be stored in `config_groups`
# Use self.resolved_targets as source of truth.
targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"])
ignore: List[str] = Field(default_factory=list)
scheme: Optional[Union[str, Dict[str, Any]]] = None
kv_cache_scheme: Optional[QuantizationArgs] = None
targets: str | list[str] | None = Field(default_factory=lambda: ["Linear"])
ignore: list[str] = Field(default_factory=list)
scheme: str | dict[str, Any] | None = None
kv_cache_scheme: QuantizationArgs | None = None

_calibration_hooks: Set[RemovableHandle] = PrivateAttr(default_factory=set)
_resolved_config: Optional[QuantizationConfig] = PrivateAttr(None)
_calibration_hooks: set[RemovableHandle] = PrivateAttr(default_factory=set)
_resolved_config: QuantizationConfig | None = PrivateAttr(None)

# NOTE: in some cases, we need to allow users to run instances of the
# QuantizationMixin without quantizing modules, e.g. when a user wants
# to run both AWQ and GPTQ before quantizing. Set this field to True
# on classes that subclass QuantizationMixin to allow for this.
_supports_disabling_quantization: bool = PrivateAttr(False)

@field_validator("targets", mode="before")
def validate_targets(cls, value: Union[str, List[str]]) -> List[str]:
def validate_targets(cls, value: str | list[str]) -> list[str]:
if isinstance(value, str):
return [value]

return value

@field_validator("scheme", mode="before")
def validate_scheme(
cls, value: Optional[Union[str, Dict[str, Any]]]
) -> Optional[Union[str, Dict[str, Any]]]:
cls, value: str | dict[str, Any] | None
) -> str | dict[str, Any] | None:
if isinstance(value, str) and not is_preset_scheme(value):
raise ValueError(
"`scheme` must either be a preset scheme name or a dictionary "
@@ -138,7 +144,7 @@ def resolved_config(self) -> QuantizationConfig:
return self._resolved_config

@property
def resolved_targets(self) -> Set[str]:
def resolved_targets(self) -> set[str]:
"""
Set of all resolved targets, i.e. all unique targets listed
in resolved quantization config.
@@ -221,6 +227,12 @@ def resolve_quantization_config(self) -> QuantizationConfig:
kv_cache_scheme = self.kv_cache_scheme
ignore = self.ignore

# NOTE: this will only happen if user explicitly sets targets=None
if targets is None and config_groups is None:
if self._supports_disabling_quantization:
return QuantizationConfig({})
raise ValueError("Please specify either `targets` or `config_groups`")

if scheme is not None and config_groups is not None:
raise ValueError("Please specify either `scheme` or `config_groups`")

@@ -286,7 +298,7 @@ def _initialize_observers(self, module: torch.nn.Module):
if output:
initialize_observer(module, base_name="output")

def _initialize_hooks(self, module: torch.nn.Module) -> Set[RemovableHandle]:
def _initialize_hooks(self, module: torch.nn.Module) -> set[RemovableHandle]:
hooks = set()
if not hasattr(module, "quantization_scheme"):
return hooks
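As the docstring note above says, `targets` is not the sole source of truth: quantization can also be driven entirely by `config_groups`, with resolved targets coming from each group's own target list. A minimal sketch, assuming the compressed_tensors QuantizationScheme/QuantizationArgs API; field values are illustrative.

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from llmcompressor.modifiers.quantization import QuantizationModifier

# one group targeting all Linear layers with 4-bit grouped weight quantization
group = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, group_size=128, strategy="group", symmetric=True),
)

modifier = QuantizationModifier(config_groups={"group_0": group}, ignore=["lm_head"])
# resolved_targets then reports "Linear" via the group, per the NOTE above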
6 changes: 6 additions & 0 deletions tests/llmcompressor/modifiers/awq/test_base.py
@@ -228,3 +228,9 @@ def test_get_lowest_common_parent():
["embed_tokens", "decoder.self_attn.v_proj"], model
)
assert parent_name == "" and parent == model


def test_awq_supports_disabling_quantization():
awq = AWQModifier(scheme="W4A16", targets=None)

assert len(awq.resolved_config.config_groups) == 0
10 changes: 9 additions & 1 deletion tests/llmcompressor/modifiers/quantization/test_base.py
@@ -3,7 +3,7 @@
import pytest
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier


@pytest.fixture
@@ -211,3 +211,11 @@ def test_resolved_targets(
)

assert modifier.resolved_targets == resolved_targets


def test_does_not_support_disabling_quantization():
with pytest.raises(ValueError):
GPTQModifier(scheme="W4A16", targets=None).resolve_quantization_config()

with pytest.raises(ValueError):
QuantizationModifier(scheme="W4A16", targets=None).resolve_quantization_config()