Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 124 additions & 2 deletions src/compressed_tensors/quantization/lifecycle/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
initialize_module_for_quantization,
is_attention_module,
)
from typing import Iterable

from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationStrategy,
)
from compressed_tensors.quantization.quant_config import (
QuantizationConfig,
QuantizationStatus,
Expand Down Expand Up @@ -110,7 +113,10 @@ def load_pretrained_quantization_parameters(


def apply_quantization_config(
model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
model: Module,
config: Union[QuantizationConfig, None],
run_compressed: bool = False,
validate_group_or_block_size: bool = True,
):
"""
Initializes the model for quantization in-place based on the given config.
Expand All @@ -120,6 +126,8 @@ def apply_quantization_config(
:param config: quantization config
:param run_compressed: Whether the model will be run in compressed mode or
decompressed fully on load
:param validate_group_or_block_size: if True, validates that weight dimensions are
evenly divisible by group_size or block_structure. Defaults to True.
"""
from compressed_tensors.linear.compressed_linear import CompressedLinear

Expand Down Expand Up @@ -182,6 +190,11 @@ def apply_quantization_config(

submodule.quantization_status = config.quantization_status

# Validate group/block size divisibility if enabled
if validate_group_or_block_size:
match_generator = match_named_modules(model, target_to_scheme, config.ignore)
_validate_group_or_block_size(match_generator)


def _apply_kv_cache_scheme(
model: torch.nn.Module,
Expand Down Expand Up @@ -258,3 +271,112 @@ def _scheme_from_targets(
# return the first scheme (the prioritized one,
# since the order of target_to_scheme matters)
return target_to_scheme[targets[0]]


def _validate_group_or_block_size(modules: list[tuple[str, Module]]) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the type is incorrect for match_named_modules output

Suggested change
def _validate_group_or_block_size(modules: list[tuple[str, Module]]) -> None:
def _validate_group_or_block_size(modules: Iterable[tuple[str, Module]]) -> None:
    """
    Validates quantization parameter divisibility for all modules:
    - GROUP strategy: weight columns must be evenly divisible by group_size
    - BLOCK strategy: weight dimensions must be evenly divisible by block_structure

    Raises a ValueError if validation fails, providing a comprehensive error
    message with suggested fixes.

    :param modules: Iterable of (fqn, module) tuples to validate. It is consumed
        exactly once, so a one-shot generator (such as the output of
        `match_named_modules`) is accepted as well as a list
    :raises ValueError: If any module has dimension divisibility issues
    """
    # Collect every offending layer before raising so that the error message
    # can list all of them at once rather than only the first failure
    problematic_layers = [
        fqn
        for fqn, module in modules
        if _check_module_divisibility(fqn, module) is not None
    ]

    if problematic_layers:
        raise ValueError(_generate_divisibility_error_message(problematic_layers))


def _check_module_divisibility(fqn: str, module: Module) -> Optional[str]:
    """
    Checks a single module for group size divisibility (GROUP strategy) and
    block structure divisibility (BLOCK strategy).

    The module is skipped (None is returned) when it has no
    `quantization_scheme`, when the scheme has no weight quantization args,
    when the module has no `weight` (attribute missing or set to None), or when
    the `group_size` / `block_structure` relevant to its strategy is unset.
    Strategies other than GROUP and BLOCK are never reported.

    :param fqn: Fully qualified name of the module
    :param module: Module to check
    :return: fqn if there's an issue, None otherwise
    """
    quant_scheme = getattr(module, "quantization_scheme", None)
    if quant_scheme is None:
        return None

    quant_args = quant_scheme.weights
    if quant_args is None:
        return None

    # `hasattr` is not sufficient here: a module may expose `weight` as None
    # (e.g. a parameter registered as None), and reading `.shape` on it would
    # raise AttributeError instead of skipping the module
    weight = getattr(module, "weight", None)
    if weight is None:
        return None

    # Conv2d weights are indexed from the front (dims 0 and 1); every other
    # module type is indexed from the back (last two dims)
    is_conv2d = isinstance(module, torch.nn.Conv2d)
    strategy = quant_args.strategy

    # Validate for GROUP strategy
    if strategy == QuantizationStrategy.GROUP:
        group_size = quant_args.group_size
        if group_size is None:
            return None

        num_columns = weight.shape[1] if is_conv2d else weight.shape[-1]
        if num_columns % group_size != 0:
            return fqn

    # Validate for BLOCK strategy
    elif strategy == QuantizationStrategy.BLOCK:
        block_structure = quant_args.block_structure
        if block_structure is None:
            return None

        block_height, block_width = block_structure

        if is_conv2d:
            num_rows, num_columns = weight.shape[:2]
        else:
            num_rows, num_columns = weight.shape[-2:]

        # Both dimensions must tile evenly into blocks
        if num_rows % block_height != 0 or num_columns % block_width != 0:
            return fqn

    return None


def _generate_divisibility_error_message(problematic_layers: List[str]) -> str:
    """
    Generate error message for quantization divisibility validation failures.

    The message consists of a banner, a short description of the failure, a
    suggested `ignore` entry listing the offending layers, and a closing note.

    :param problematic_layers: List of layer names with divisibility issues
    :return: Formatted error message
    """
    banner_rule = "=" * 80
    section_rule = "-" * 80

    parts = [
        "\n" + banner_rule + "\n",
        "ERROR: Quantization divisibility validation failed!\n",
        banner_rule + "\n\n",
        "Found layers with weight dimensions that are not evenly divisible\n",
        "by the specified group_size or block_structure.\n\n",
        "\n" + section_rule + "\n",
        "SUGGESTED FIX: Add the following to your quantization config:\n\n",
        f" ignore: {problematic_layers}\n",
        "\n" + section_rule + "\n",
        # The trailing space and newline below are required: without them the
        # two halves of the note run together ("defaultalso") and the closing
        # rule is glued onto the end of the sentence
        "\nNote: any modules like lm_head which are ignored by default ",
        "also need to be manually ignored\n",
        banner_rule + "\n",
    ]

    return "".join(parts)
Loading
Loading