2 changes: 1 addition & 1 deletion src/llmcompressor/__init__.py
@@ -26,4 +26,4 @@
create_session,
reset_session,
)
from llmcompressor.entrypoints import Oneshot, oneshot, train
from llmcompressor.entrypoints import Oneshot, oneshot, train, model_free_ptq
1 change: 1 addition & 0 deletions src/llmcompressor/entrypoints/__init__.py
@@ -9,4 +9,5 @@

from .oneshot import Oneshot, oneshot
from .train import train
from .model_free import model_free_ptq
from .utils import post_process, pre_process
138 changes: 138 additions & 0 deletions src/llmcompressor/entrypoints/model_free/__init__.py
@@ -0,0 +1,138 @@
import os
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional

import torch
import tqdm
from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.utils.match import _match_name
from loguru import logger
from safetensors.torch import load_file, save_file

from llmcompressor.entrypoints.model_free.helpers import (
gpu_if_available,
validate_scheme,
)
from llmcompressor.entrypoints.model_free.lifecycle import (
calibrate_weights,
compress_module,
initialize_quantized_linear,
)
from llmcompressor.entrypoints.model_free.model_utils import (
get_checkpoint_files,
is_weights_file,
)
from llmcompressor.entrypoints.model_free.save_utils import (
update_config,
update_safetensors_index,
)

__all__ = ["model_free_ptq"]


def model_free_ptq(
model_stub: str | os.PathLike,
save_directory: str | os.PathLike,
scheme: QuantizationScheme | str,
ignore: Optional[list[str]] = None,
max_workers: int = 1,
device: Optional[torch.device | str] = None,
):
"""
    Quantize a model without the need for a model definition. This function operates
    on a model stub or a local folder containing weights saved in safetensors files.

    :param model_stub: huggingface model stub or path to local weights files
    :param save_directory: directory to save the quantized model files to
    :param scheme: weight quantization scheme or preset scheme name
    :param ignore: modules to ignore. Modules ending with "norm" are automatically
        ignored
    :param max_workers: number of worker threads to process files with
    :param device: gpu device to accelerate quantization with
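
    Example (illustrative sketch; the model stub, save path, and preset name below
    are placeholders):

        from llmcompressor import model_free_ptq

        model_free_ptq(
            model_stub="org/model-name",
            save_directory="model-name-W4A16",
            scheme="W4A16",
            max_workers=4,
        )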
"""
# validate arguments
model_files = get_checkpoint_files(model_stub)
scheme_name, scheme = validate_scheme(scheme)
device = gpu_if_available(device)

# 0. collect safetensors files, copy files
jobs = []
for file_path, resolved_path in model_files:
save_path = Path(save_directory) / file_path

if file_path.endswith("safetensors"):
jobs.append(
(_process_file, resolved_path, save_path, scheme, ignore, device)
)

        else:
            if is_weights_file(file_path):
                logger.warning(f"Skipping weights file {file_path}")
                continue
            save_path.parent.mkdir(parents=True, exist_ok=True)
            logger.info(f"Copying {file_path} to {save_path}")
            shutil.copyfile(resolved_path, save_path)

# 1-4. quantize and compress weights
with ThreadPoolExecutor(max_workers) as executor:
futures = [executor.submit(*job) for job in jobs]

total_size = 0
weight_map = dict()
for future in tqdm.tqdm(
as_completed(futures), total=len(futures), desc="Quantizing"
):
_total_size, _weight_map = future.result()
total_size += _total_size
weight_map.update(_weight_map)

# 5. update config and safetensors index
update_config(save_directory, scheme_name, scheme, ignore)
update_safetensors_index(save_directory, total_size, weight_map)


def _process_file(
file_path: str | os.PathLike,
save_path: str | os.PathLike,
scheme: QuantizationScheme,
    ignore: Optional[str | list[str]],
device: str | torch.device,
) -> tuple[int, dict[str, str]]:
"""
Quantize and compress tensors in a given safetensors file

:param file_path: safetensors file to process
:param save_path: save path of file with quantized weights
:param scheme: quantization scheme to apply to tensors
:param ignore: modules to ignore. Modules ending with "norm" are automatically
ignored
:param device: device used to quantize and compress weights
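    :return: tuple of (total bytes written to the output file, mapping from tensor
        names to the output file name)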
"""
tensors = load_file(file_path)

for name in list(tensors.keys()):
module_name, param_name = name.rsplit(".", 1)
is_linear_weight = param_name == "weight" and not module_name.endswith("norm")
        is_ignored = any(_match_name(module_name, ign) for ign in (ignore or []))
if not is_linear_weight or is_ignored:
continue

# 1. initialize module with qparams (on device)
module = initialize_quantized_linear(tensors[name], scheme, device)

# 2. calibrate weight qparams
calibrate_weights(module)

# 3. compress module using qparams
compress_module(module)

# 4. save compressed data (on cpu)
del tensors[name]
prefix = module_name + "."
for key, value in module.state_dict(prefix=prefix).items():
tensors[key] = value.to("cpu")

save_file(tensors, save_path)
total_size = sum(tensor.nbytes for tensor in tensors.values())
weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
return total_size, weight_map
75 changes: 75 additions & 0 deletions src/llmcompressor/entrypoints/model_free/helpers.py
@@ -0,0 +1,75 @@
from typing import Optional

import torch
from compressed_tensors.quantization import QuantizationScheme, preset_name_to_scheme
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.match import _match_name
from loguru import logger

__all__ = ["validate_scheme", "gpu_if_available", "is_match_name"]


def validate_scheme(
    scheme: QuantizationScheme | str,
) -> tuple[str, QuantizationScheme]:
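    """
    Resolve a preset scheme name if necessary and validate that the scheme quantizes
    weights only (activation quantization, if any, must be dynamic)

    :param scheme: quantization scheme or preset scheme name
    :return: tuple of (scheme name, resolved quantization scheme)
    """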
# treat strings as preset schemes
if isinstance(scheme, str):
scheme_name, scheme = scheme, preset_name_to_scheme(scheme, [])
else:
scheme_name = "config_group_0"

# weight quantization must be provided
if scheme.weights is None:
raise ValueError(
"Must provide a weights quanitization scheme to perform weights-only PTQ"
)

# activation quantization must be dynamic
input_dynamic = getattr_chain(scheme, "input_activations.dynamic", True)
output_dynamic = getattr_chain(scheme, "output_activations.dynamic", True)
if input_dynamic is not True or output_dynamic is not True:
raise ValueError(
"Model Free PTQ cannot calibrate activations. "
"Please use `oneshot` instead."
)

# override with static observers
# Remove after https://github.com/vllm-project/compressed-tensors/pull/489
if scheme.weights.observer in ("minmax", "mse"):
new_observer = f"static_{scheme.weights.observer}"
logger.warning(
f"Scheme uses {scheme.weights.observer} weight observer. "
f"Using {new_observer} instead"
)
scheme.weights.observer = new_observer

# target all modules; filter by ignore list
    # technically this should be "re:.*", but vLLM's
    # compressed-tensors MoE layer has a hard-coded check for "Linear"
scheme.targets = ["Linear"]
return scheme_name, scheme


def gpu_if_available(device: torch.device | str | None) -> torch.device:
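    """
    Resolve the device to quantize on, preferring CUDA, then XPU, then CPU

    :param device: requested device, or None to select one automatically
    :return: resolved torch device
    """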
if device is not None:
return torch.device(device)

elif torch.cuda.is_available():
return torch.device("cuda:0")

elif hasattr(torch, "xpu") and torch.xpu.is_available():
return torch.device("xpu:0")

else:
logger.warning("CUDA/XPU is not available! Compressing model on CPU instead")
return torch.device("cpu")


def is_match_name(
name: str, targets: list[str], ignore: Optional[str | list[str]] = None
) -> bool:
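    """
    Check whether a module name matches any of the targets and none of the ignore
    patterns

    :param name: module name to check
    :param targets: target pattern(s) the name must match
    :param ignore: pattern(s) which exclude the name even if a target matches
    :return: True if the name matches a target and is not ignored
    """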
targets = targets if isinstance(targets, list) else [targets]
    if ignore is None:
        ignore = []
    elif not isinstance(ignore, list):
        ignore = [ignore]

matches_target = any(_match_name(name, target) for target in targets)
matches_ignore = any(_match_name(name, ign) for ign in ignore)

return matches_target and not matches_ignore
73 changes: 73 additions & 0 deletions src/llmcompressor/entrypoints/model_free/lifecycle.py
@@ -0,0 +1,73 @@
import torch
from compressed_tensors.compressors import BaseCompressor
from compressed_tensors.config.format import _get_quant_compression_format
from compressed_tensors.quantization import (
QuantizationScheme,
QuantizationStrategy,
initialize_module_for_quantization,
)

from llmcompressor.modifiers.quantization.calibration import (
apply_calibration_status,
freeze_module_quantization,
initialize_observer,
update_weight_global_scale,
update_weight_zp_scale,
)

__all__ = [
"initialize_quantized_linear",
"calibrate_weights",
"compress_module",
]


def initialize_quantized_linear(
weight: torch.Tensor, scheme: QuantizationScheme, device: str | torch.device
) -> torch.nn.Module:
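    """
    Construct a bias-free Linear module from a weight tensor and attach the
    quantization parameters required by the given scheme

    :param weight: weight tensor to copy into the new module
    :param scheme: quantization scheme to initialize the module with
    :param device: device to place the module on
    :return: linear module ready for weight calibration
    """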
out_features, in_features = weight.shape
module = torch.nn.Linear(
in_features, out_features, bias=False, device=device, dtype=weight.dtype
)
module.weight.data.copy_(weight)
initialize_module_for_quantization(module, scheme, force_zero_point=False)

return module


def calibrate_weights(module: torch.nn.Linear):
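    """
    Calibrate weight quantization parameters (scale, zero point, and global scale
    for tensor-group schemes) for a module initialized for quantization
    """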
scheme: QuantizationScheme = getattr(module, "quantization_scheme")
initialize_observer(module, "weight")

apply_calibration_status(module)
if scheme.weights.strategy == QuantizationStrategy.TENSOR_GROUP:
update_weight_global_scale(module)
update_weight_zp_scale(module)

freeze_module_quantization(module)


def compress_module(module: torch.nn.Linear):
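    """
    Compress a calibrated module in place, replacing its weight with the compressed
    representation for the scheme's inferred compression format
    """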
scheme: QuantizationScheme = getattr(module, "quantization_scheme")

format = _get_quant_compression_format(scheme.input_activations, scheme.weights)
scheme.format = format.value

compressor = BaseCompressor.load_from_registry(format.value)
data = compressor.compress_weight(
module.weight,
quantization_args=scheme.weights,
scale=getattr(module, "weight_scale"),
zero_point=getattr(module, "weight_zero_point", None),
global_scale=getattr(module, "weight_global_scale", None),
)

# `compress_weight` is a messy api
delattr(module, "weight")
for key, value in data.items():
if hasattr(module, key):
getattr(module, key).data = value
else:
module.register_parameter(
key, torch.nn.Parameter(value, requires_grad=False)
)
48 changes: 48 additions & 0 deletions src/llmcompressor/entrypoints/model_free/model_utils.py
@@ -0,0 +1,48 @@
import os

from huggingface_hub import list_repo_files
from transformers.utils.hub import cached_file

__all__ = ["get_checkpoint_files", "is_weights_file"]

weights_files = [
".bin",
".safetensors",
".pth",
".msgpack",
".pt",
]


def is_weights_file(file_name: str) -> bool:
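    """
    Check whether a file name has a known weights file extension
    """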
return any(file_name.endswith(suffix) for suffix in weights_files)


def get_checkpoint_files(model_stub: str | os.PathLike) -> list[tuple[str, str]]:
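    """
    List checkpoint files for a model stub or local model directory

    :param model_stub: huggingface model stub or path to a local model folder
    :return: list of (relative file path, resolved file path) tuples, where relative
        paths determine save locations and resolved paths are used to load data
    """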
# In the future, this function can accept and pass download kwargs to cached_file

if os.path.exists(model_stub):
file_paths = walk_file_paths(model_stub, ignore=".cache")
else:
file_paths = list_repo_files(model_stub)

return [(file_path, cached_file(model_stub, file_path)) for file_path in file_paths]


def walk_file_paths(root_dir: str, ignore: str | None = None) -> list[str]:
"""
Return all file paths relative to the root directory
"""

all_files = []
for dirpath, _, filenames in os.walk(root_dir):
for filename in filenames:
rel_path = os.path.relpath(os.path.join(dirpath, filename), root_dir)
if not (ignore and rel_path.startswith(ignore)):
all_files.append(rel_path)
return all_files


# distinguish relative file paths from absolute/resolved file paths
# relative file paths are used to find the save path
# resolved file paths are what are used to load data