forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Misc] Add per channel support for static activation quantization; up…
…date w8a8 schemes to share base classes (vllm-project#5650)
- Loading branch information
Showing
5 changed files
with
121 additions
and
136 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
84 changes: 84 additions & 0 deletions
84
.../model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from typing import Callable, List, Tuple, Union | ||
|
||
import torch | ||
from torch.nn import Parameter | ||
|
||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( | ||
CompressedTensorsScheme) | ||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( | ||
QuantizationStrategy) | ||
from vllm.model_executor.utils import set_weight_attrs | ||
|
||
|
||
class CompressedTensorsW8A8(CompressedTensorsScheme): | ||
|
||
def __init__(self, strategy: str): | ||
self.strategy = strategy | ||
|
||
def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: | ||
if isinstance(shard_id, int): | ||
return shard_id | ||
|
||
assert isinstance(shard_id, str) | ||
qkv_idxs = {"q": 0, "k": 1, "v": 2} | ||
assert shard_id in qkv_idxs | ||
return qkv_idxs[shard_id] | ||
|
||
def scales_shard_splitter( | ||
self, param: torch.Tensor, loaded_weight: torch.Tensor, | ||
shard_id: Union[str, int], | ||
logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | ||
shard_id = self._shard_id_as_int(shard_id) | ||
offset = sum(logical_widths[:shard_id]) | ||
size = logical_widths[shard_id] | ||
# update loaded weight with copies for broadcast. | ||
loaded_weight = loaded_weight.repeat(size) | ||
return param[offset:offset + size], loaded_weight | ||
|
||
def create_weights(self, layer: torch.nn.Module, | ||
output_partition_sizes: List[int], | ||
input_size_per_partition: int, | ||
params_dtype: torch.dtype, weight_loader: Callable, | ||
**kwargs): | ||
|
||
is_tensor_partitioned = len(output_partition_sizes) != 1 | ||
weight_scale_dim = sum(output_partition_sizes) if ( | ||
is_tensor_partitioned | ||
or self.strategy == QuantizationStrategy.CHANNEL) else 1 | ||
|
||
shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, ) | ||
if self.strategy == QuantizationStrategy.CHANNEL: | ||
shape = (weight_scale_dim, 1) | ||
|
||
weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32), | ||
requires_grad=False) | ||
|
||
layer.register_parameter("weight_scale", weight_scale) | ||
set_weight_attrs(weight_scale, {"weight_loader": weight_loader}) | ||
|
||
weight = Parameter(torch.empty(sum(output_partition_sizes), | ||
input_size_per_partition, | ||
dtype=torch.int8), | ||
requires_grad=False) | ||
|
||
layer.register_parameter("weight", weight) | ||
set_weight_attrs( | ||
weight, { | ||
"input_dim": 1, | ||
"output_dim": 0, | ||
"weight_loader": weight_loader, | ||
"logical_widths": output_partition_sizes | ||
}) | ||
|
||
# Don't need a shard_splitter for channel-wise quantization | ||
# Use the default loading method | ||
if self.strategy == QuantizationStrategy.CHANNEL: | ||
set_weight_attrs(weight_scale, { | ||
"output_dim": 0, | ||
}) | ||
else: | ||
set_weight_attrs( | ||
weight_scale, { | ||
"logical_widths": output_partition_sizes, | ||
"shard_splitter": self.scales_shard_splitter, | ||
}) |
89 changes: 10 additions & 79 deletions
89
...or/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
60 changes: 10 additions & 50 deletions
60
...or/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters