
Some fixes for AWQ #269

Merged 5 commits on Apr 10, 2025

13 changes: 13 additions & 0 deletions src/compressed_tensors/quantization/quant_scheme.py
@@ -142,6 +142,18 @@ def is_preset_scheme(name: str) -> bool:
),
)

# 4 bit integer weights only, asymmetric quantization
W4A16_ASYM = dict(
weights=QuantizationArgs(
num_bits=4,
type=QuantizationType.INT,
strategy=QuantizationStrategy.GROUP,
group_size=128,
symmetric=False,
dynamic=False,
),
)

# 4 bit integer weights and 8 bit activations quantization
INT8_W4A8 = dict(
weights=QuantizationArgs(
@@ -205,6 +217,7 @@ def is_preset_scheme(name: str) -> bool:
# Integer weight only schemes
"W8A16": W8A16,
"W4A16": W4A16,
"W4A16_ASYM": W4A16_ASYM,
# Integer weight and activation schemes
"W8A8": INT8_W8A8,
"INT8": INT8_W8A8, # alias for W8A8
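
For reference, a minimal usage sketch of the new preset. This is a sketch only: the name lookup via `is_preset_scheme` and the module-level `W4A16_ASYM` dict are taken from the diff above, and the import path simply mirrors the file being edited.

```python
# Minimal usage sketch for the new preset; the import path mirrors the file
# touched above (src/compressed_tensors/quantization/quant_scheme.py).
from compressed_tensors.quantization.quant_scheme import W4A16_ASYM, is_preset_scheme

# The preset is now resolvable by name alongside the existing W4A16 preset
assert is_preset_scheme("W4A16_ASYM")

# and carries asymmetric, group-wise 4-bit integer weight args
weight_args = W4A16_ASYM["weights"]
print(weight_args.num_bits, weight_args.symmetric, weight_args.group_size)
# -> 4 False 128
```
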
8 changes: 7 additions & 1 deletion src/compressed_tensors/quantization/utils/helpers.py
@@ -64,8 +64,11 @@ def calculate_qparams(
:param quantization_args: settings to quantization
:return: tuple of the calculated scale(s) and zero point(s)
"""
# based on the implementations for consuming quantized values,
# 0.0 must always be representable within the quantized range
min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
max_vals = torch.max(max_vals, torch.zeros_like(max_vals))

device = min_vals.device

bit_min, bit_max = calculate_range(quantization_args, device)
@@ -84,6 +87,9 @@
zero_points = torch.clamp(zero_points, bit_min, bit_max)

# match zero-points to quantized type
# if casting to int, use round instead of truncate
if quantization_args.type == QuantizationType.INT:
zero_points = torch.round(zero_points)
zero_points = zero_points.to(zp_dtype)

if scales.ndim == 0:
@@ -96,7 +102,7 @@
def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
"""
Returns the computed scales and zero points for dynamic activation
qunatization.
quantization.

:param value: tensor to calculate quantization parameters for
:param args: quantization args
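
Taken together, the helpers.py changes do two things for integer quantization (including the new asymmetric 4-bit preset): 0.0 is forced into the observed min/max range, and integer zero points are rounded rather than truncated before casting. Below is a standalone sketch of that arithmetic using the standard asymmetric formula; it is not the library's exact `calculate_qparams` code.

```python
import torch


def sketch_asym_qparams(min_vals: torch.Tensor, max_vals: torch.Tensor, num_bits: int = 4):
    """Standalone sketch of asymmetric scale/zero-point computation."""
    # based on the implementations for consuming quantized values,
    # 0.0 must always be representable within the quantized range
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))

    # unsigned asymmetric integer range, e.g. [0, 15] for 4 bits
    bit_min, bit_max = 0, 2**num_bits - 1

    scales = (max_vals - min_vals) / (bit_max - bit_min)
    scales = torch.clamp(scales, min=torch.finfo(scales.dtype).eps)

    zero_points = bit_min - min_vals / scales
    zero_points = torch.clamp(zero_points, bit_min, bit_max)

    # if casting to int, use round instead of truncate
    zero_points = torch.round(zero_points).to(torch.int8)
    return scales, zero_points


# per-group min/max for two groups of weights
scales, zero_points = sketch_asym_qparams(
    torch.tensor([-0.8, 0.1]), torch.tensor([1.2, 0.9])
)
print(scales, zero_points)  # the first group's zero point rounds to 6
```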