2 changes: 2 additions & 0 deletions examples/awq/qwen3_moe_example.py
@@ -55,6 +55,8 @@ def tokenize(sample):
ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"],
scheme="W4A16",
targets=["Linear"],
use_auto_awq_mem_hack=False, # GPU VRAM 37784MiB
# use_auto_awq_mem_hack=True, # GPU VRAM 37792MiB
),
]

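For reference, the new field is just another `AWQModifier` argument, so it can be toggled from any recipe. Below is a minimal sketch of how the flag slots into a oneshot run; the model name, dataset, and calibration settings are illustrative placeholders, not part of this example file.

```python
from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

# Hypothetical recipe mirroring the example above; only the new flag differs.
recipe = [
    AWQModifier(
        ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"],
        scheme="W4A16",
        targets=["Linear"],
        use_auto_awq_mem_hack=False,  # cache full forward kwargs per parent module
    ),
]

# Placeholder model/dataset; any MoE checkpoint and calibration set would do.
oneshot(
    model="Qwen/Qwen3-30B-A3B",
    dataset="open_platypus",
    recipe=recipe,
    max_seq_length=512,
    num_calibration_samples=256,
)
```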
77 changes: 60 additions & 17 deletions src/llmcompressor/modifiers/awq/base.py
@@ -1,5 +1,5 @@
import inspect
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union, Any

import torch
from compressed_tensors.quantization import disable_quantization
@@ -122,6 +122,7 @@ class AWQModifier(Modifier, QuantizationMixin):
mappings: Optional[List[AWQMapping]] = None
offload_device: Optional[torch.device] = None
duo_scaling: bool = True
use_auto_awq_mem_hack: bool = True

# Private vars set during validation
_num_bits: Optional[int] = PrivateAttr(default=None)
@@ -130,8 +131,10 @@ class AWQModifier(Modifier, QuantizationMixin):

# Private vars set during initialization, cleared during finalization
_resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list)
# Cache list of forward input args for each parent module, one dict for each batch
_parent_args_cache: Dict[Module, IntermediatesCache] = PrivateAttr(
# Model-wide cache of kwargs shared by all parent modules
_model_kwargs_cache: IntermediatesCache = PrivateAttr()
# Cache of forward kwargs for each parent module, one entry for each batch
# (only the first arg, e.g. hidden_states, when use_auto_awq_mem_hack is enabled)
_parent_kwargs_cache: dict[Module, IntermediatesCache] = PrivateAttr(
default_factory=dict
)
# Dict[smooth layer name, (activation means, activation counts)]
@@ -290,7 +293,8 @@ def on_finalize(self, state: State, **kwargs) -> bool:
if not self.ended_:
self.on_end(state, None)

self._parent_args_cache.clear()
self._parent_kwargs_cache.clear()
self._model_kwargs_cache = None
self._smooth_activation_means.clear()
self._resolved_mappings.clear()

@@ -387,13 +391,37 @@ def _setup_activation_cache_hooks(self) -> None:
calculate the dynamic range during calibration
"""

def cache_parent_kwargs_hook(
def cache_hidden_states_kwargs_hook(
module: torch.nn.Module,
args: Tuple[torch.Tensor, ...],
kwargs,
):
batch_idx = len(self._parent_kwargs_cache[module])

values = inspect.signature(module.forward).bind(*args, **kwargs)
self._parent_args_cache[module].append(values.arguments)

# our original impl: all kwargs are cached for each parent
# technically correct way, but probably lots of redundancy
if not self.use_auto_awq_mem_hack:
self._parent_kwargs_cache[module].append(values.arguments)
return

# autoawq impl: only first param is cached for each parent
# all others are pulled from model-wide cache
# much more memory efficient, but possibly incorrect
# depending on model definition
first_param_name, first_arg = next(iter(values.arguments.items()))

self._parent_kwargs_cache[module].append({first_param_name: first_arg})

values.arguments.pop(first_param_name)

if len(self._model_kwargs_cache) < batch_idx:
raise ValueError(f"model-wide kwargs cache is missing batches before index {batch_idx}")
elif len(self._model_kwargs_cache) == batch_idx:
self._model_kwargs_cache.append(values.arguments)
else:
self._model_kwargs_cache.update(batch_idx, values.arguments)

def create_cache_smooth_activations_hook_fn(smooth_name):
def cache_smooth_activations_hook(
@@ -409,17 +437,19 @@ def cache_smooth_activations_hook(

return cache_smooth_activations_hook

# Don't offload this cache, it is reused for every parent module's forward pass
self._model_kwargs_cache = IntermediatesCache(None, None)
for mapping in self._resolved_mappings:
# parent kwargs needed for future forward passes
# same parent may appear multiple times in resolved mappings
if mapping.parent not in self._parent_args_cache:
self._parent_args_cache[mapping.parent] = IntermediatesCache(
if mapping.parent not in self._parent_kwargs_cache:
self._parent_kwargs_cache[mapping.parent] = IntermediatesCache(
None,
self.offload_device,
)
self.register_hook(
mapping.parent,
cache_parent_kwargs_hook,
cache_hidden_states_kwargs_hook,
"forward_pre",
with_kwargs=True,
)
@@ -555,19 +585,28 @@ def _smooth(module):
# remove caches needed to smooth this mapping
del self._smooth_activation_means[mapping.smooth_name]

for v in self._parent_args_cache.values():
for v in self._parent_kwargs_cache.values():
v.batch_intermediates.clear()
self._assert_all_activations_consumed()

def _run_samples(self, module: Module) -> List[torch.Tensor]:
outputs = [
module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module]
]
return [
outputs = []
parameter_keys = inspect.signature(module.forward).parameters.keys()

for batch_idx in range(len(self._parent_kwargs_cache[module])):
batch_kwargs = self._model_kwargs_cache.fetch(
batch_idx, ignore_missing=True
)
batch_kwargs.update(self._parent_kwargs_cache[module].fetch(batch_idx))
batch_kwargs = {
k: v for k, v in batch_kwargs.items() if k in parameter_keys
}

output = module(**batch_kwargs)
# If Tuple, assume that first argument is the input
output[0] if isinstance(output, Tuple) else output
for output in outputs
]
outputs.append(output[0] if isinstance(output, Tuple) else output)

return outputs

def _compute_best_scale(
self,
@@ -592,6 +631,10 @@ def _compute_best_scale(
best_scales = None
best_error = float("inf")

# NOTE: this changes the module pointers, so it invalidates
# the field `_parent_kwargs_cache: dict[Module, IntermediatesCache]`
# parent_module = torch.compile(parent_module)

org_sd = {
k: v.cpu()
for k, v in parent_module.state_dict().items()
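To make the two code paths in `cache_hidden_states_kwargs_hook` easier to compare, here is a stripped-down sketch of the caching logic, with plain lists and dicts standing in for `IntermediatesCache`; the helper name and signature are illustrative, not part of the modifier.

```python
import inspect
from typing import Any, Dict, List


def cache_forward_kwargs(
    module,
    args,
    kwargs,
    parent_cache: List[Dict[str, Any]],  # one cache per parent module
    model_cache: List[Dict[str, Any]],   # single cache shared by all parents
    use_auto_awq_mem_hack: bool,
) -> None:
    """Toy version of the pre-forward hook above."""
    batch_idx = len(parent_cache)
    bound = inspect.signature(module.forward).bind(*args, **kwargs).arguments

    if not use_auto_awq_mem_hack:
        # original path: every parent keeps a full copy of its kwargs for every batch
        parent_cache.append(dict(bound))
        return

    # autoawq-style path: keep only the first arg (e.g. hidden_states) per parent,
    # and store the remaining kwargs once, model-wide, reused by every parent
    first_name, first_val = next(iter(bound.items()))
    parent_cache.append({first_name: first_val})
    bound.pop(first_name)

    if batch_idx == len(model_cache):
        model_cache.append(dict(bound))       # first parent seen for this batch
    else:
        model_cache[batch_idx].update(bound)  # later parents refresh the shared kwargs
```

The saving comes from the shared `model_cache`: with N parent modules and B calibration batches, the non-hidden-state kwargs (attention masks, position ids, etc.) are stored once per batch instead of N times, at the cost of assuming every parent receives the same extra kwargs, which is why the diff flags the hack as possibly incorrect depending on the model definition.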
10 changes: 9 additions & 1 deletion src/llmcompressor/pipelines/cache.py
@@ -90,15 +90,23 @@ def from_dataloader(
return cls(batch_intermediates, offload_device)

def fetch(
self, batch_index: int, input_names: Optional[List[str]] = None
self,
batch_index: int,
input_names: Optional[List[str]] = None,
ignore_missing: bool = False,
) -> Dict[str, Any]:
"""
Fetch values belonging to a batch

:param batch_index: index of batch whose values are being fetched
:param input_names: list of keys whose values are being fetched
:param ignore_missing: if True, return an empty dict when no intermediates
exist for the given batch_index; otherwise an IndexError is raised
:return: dictionary mapping keys to onloaded values
"""
if ignore_missing and batch_index >= len(self.batch_intermediates):
return {}
intermediates = self.batch_intermediates[batch_index]

return {
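Finally, a short usage sketch of the new `ignore_missing` flag (not a test from this PR): when `use_auto_awq_mem_hack=False`, the model-wide cache is created but never populated, and `_run_samples` still calls `fetch(..., ignore_missing=True)`, which now degrades to an empty dict instead of indexing past the end of the list.

```python
import torch

from llmcompressor.pipelines.cache import IntermediatesCache

# Same construction AWQModifier uses for its model-wide kwargs cache.
cache = IntermediatesCache(None, None)
cache.append({"attention_mask": torch.ones(1, 8)})

cache.fetch(0)                       # {"attention_mask": tensor([...])}
cache.fetch(3, ignore_missing=True)  # {} -- nothing was cached for batch 3
# cache.fetch(3)                     # without the flag this indexes past the list and raises
```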