[Misc] GPTQ Activation Ordering #8135

Merged: 11 commits, Sep 9, 2024
Changes from 1 commit
condition on config
kylesayrs committed Sep 5, 2024
commit a3df3488eed5d1914e3b80e0df2346753a068a03
@@ -232,7 +232,8 @@ def _get_scheme_from_parts(
return CompressedTensorsWNA16(
num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy,
group_size=weight_quant.group_size)
group_size=weight_quant.group_size,
actorder=weight_quant.actorder)
Contributor:

could we just add the condition here?
actorder=weight_quant.actorder == ActivationOrdering.GROUP

Contributor Author:

We could, but it feels more logical to me to keep argument processing within the CompressedTensorsWNA16.__init__ function. This separates responsibilities and makes clear that the job of _get_scheme_from_parts is to decide which compression scheme applies, not to process the arguments once the scheme is decided.

Contributor Author:

We'd also have to rename the actorder argument of CompressedTensorsWNA16.__init__; otherwise it would be a misnomer.


# Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions
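
For context on the trade-off discussed in the review thread above, here is a condensed, illustrative sketch of the two options. The enum and class below are local stand-ins, not the real vLLM types; only the has_g_idx expression mirrors the diff.

from enum import Enum
from typing import Optional


class ActivationOrdering(str, Enum):
    # Local stub of the compressed-tensors ActivationOrdering enum.
    GROUP = "group"
    WEIGHT = "weight"


# Option A (reviewer's suggestion): resolve the condition at the call site in
# _get_scheme_from_parts, so the scheme would receive a plain bool (which is
# why the "actorder" argument name would become a misnomer, as noted above).
def scheme_flag_option_a(actorder: Optional[ActivationOrdering]) -> bool:
    return actorder == ActivationOrdering.GROUP


# Option B (this PR): forward the raw enum and let __init__ interpret it,
# keeping argument processing inside the scheme class.
class SchemeOptionB:
    def __init__(self, actorder: Optional[ActivationOrdering] = None):
        self.has_g_idx = actorder == ActivationOrdering.GROUP


assert scheme_flag_option_a(ActivationOrdering.GROUP) is True
assert SchemeOptionB(ActivationOrdering.WEIGHT).has_g_idx is False
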
@@ -5,6 +5,8 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
ActivationOrdering)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, marlin_sort_g_idx, replace_tensor,
@@ -28,11 +30,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
def __init__(self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None):
group_size: Optional[int] = None,
actorder: Optional[ActivationOrdering] = None):

self.pack_factor = 32 // num_bits
self.strategy = strategy
self.group_size = -1 if group_size is None else group_size
self.has_g_idx = actorder == ActivationOrdering.GROUP

if self.group_size == -1 and self.strategy != "channel":
raise ValueError("Marlin kernels require group quantization or "
@@ -119,15 +123,16 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
dtype=torch.int64),
weight_loader=weight_loader)

# group index (for activation reordering)
weight_g_idx = BasevLLMParameter(data=torch.full(
(input_size_per_partition, ), -1, dtype=torch.int32),
weight_loader=weight_loader)

layer.register_parameter("weight_packed", weight)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape)
layer.register_parameter("weight_g_idx", weight_g_idx)

# group index (for activation reordering)
if self.has_g_idx:
weight_g_idx = BasevLLMParameter(data=torch.full(
(input_size_per_partition, ), -1, dtype=torch.int32),
weight_loader=weight_loader)
layer.register_parameter("weight_g_idx", weight_g_idx)

layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
@@ -144,8 +149,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.output_size_per_partition, device)

# Handle sorting for activation reordering if needed.
has_g_idx = -1 not in layer.weight_g_idx
if has_g_idx:
if self.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
replace_tensor(layer, "weight_g_idx", g_idx)
@@ -174,7 +178,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
marlin_scales = marlin_permute_scales(
layer.weight_scale,
size_k=(layer.input_size
if has_g_idx else layer.input_size_per_partition),
if self.has_g_idx else layer.input_size_per_partition),
size_n=layer.output_size_per_partition,
group_size=layer.group_size)
replace_tensor(layer, "weight_scale", marlin_scales)
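
The commit title, "condition on config", sums up the behavioral change visible in this last hunk: whether activation reordering is in play is no longer inferred from sentinel values in the loaded g_idx tensor but taken once from the quantization config. A minimal sketch of the two detection strategies, with illustrative values rather than the real weight-loading path:

import torch

# Before this commit: reordering was detected from the loaded tensor itself,
# with -1 as the "no g_idx" sentinel (the parameter always existed, even when
# unused).
weight_g_idx = torch.full((8, ), -1, dtype=torch.int32)  # illustrative shape
has_g_idx_before = -1 not in weight_g_idx                # False for all -1

# After this commit: the flag is derived once in __init__ from the config
# value (actorder == ActivationOrdering.GROUP) and reused here, so weight_g_idx
# is only registered when it will actually be used.
actorder = "group"                         # illustrative config value
has_g_idx_after = actorder == "group"      # True

print(has_g_idx_before, has_g_idx_after)
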