Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Add per channel support for static activation quantization; update w8a8 schemes to share base classes #5650

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions tests/quantization/test_compressed_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
CompressedTensorsW8A8StaticTensor)


def test_compressed_tensors_w8a8_static_setup(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
@pytest.mark.parametrize("model_args", [
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor"),
("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel"),
])
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
model_path, strategy = model_args
with vllm_runner(model_path, enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
Expand All @@ -33,12 +37,14 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):

assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)

assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.weight.dtype is torch.int8
assert o_proj.weight.dtype is torch.int8
assert gate_up_proj.weight.dtype is torch.int8

assert qkv_proj.weight_scale.shard_splitter is not None
assert qkv_proj.weight_scale.logical_widths is not None
if qkv_proj.scheme.strategy == "tensor":
assert qkv_proj.weight_scale.shard_splitter is not None
assert qkv_proj.weight_scale.logical_widths is not None
assert qkv_proj.input_scale.dtype is torch.float32


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,11 @@ def get_config_filenames(cls) -> List[str]:
def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
is_tensor = (weight_quant.strategy == input_quant.strategy ==
QuantizationStrategy.TENSOR.value)
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.TENSOR.value
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_tensor = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TENSOR.value)
is_symmetric = weight_quant.symmetric and input_quant.symmetric
is_static = not weight_quant.dynamic and not input_quant.dynamic

Expand Down Expand Up @@ -131,7 +134,8 @@ def _get_schema(self, weight_quant: BaseModel,

if self.quant_format == CompressionFormat.int_quantized.value:
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8StaticTensor()
return CompressedTensorsW8A8StaticTensor(
strategy=weight_quant.strategy)

if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8DynamicToken(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import Callable, List, Tuple, Union

import torch
from torch.nn import Parameter

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.utils import set_weight_attrs


class CompressedTensorsW8A8(CompressedTensorsScheme):

def __init__(self, strategy: str):
self.strategy = strategy

def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id

assert isinstance(shard_id, str)
qkv_idxs = {"q": 0, "k": 1, "v": 2}
assert shard_id in qkv_idxs
return qkv_idxs[shard_id]

def scales_shard_splitter(
self, param: torch.Tensor, loaded_weight: torch.Tensor,
shard_id: Union[str, int],
logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shard_id = self._shard_id_as_int(shard_id)
offset = sum(logical_widths[:shard_id])
size = logical_widths[shard_id]
# update loaded weight with copies for broadcast.
loaded_weight = loaded_weight.repeat(size)
return param[offset:offset + size], loaded_weight

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

is_tensor_partitioned = len(output_partition_sizes) != 1
weight_scale_dim = sum(output_partition_sizes) if (
is_tensor_partitioned
or self.strategy == QuantizationStrategy.CHANNEL) else 1

shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
if self.strategy == QuantizationStrategy.CHANNEL:
shape = (weight_scale_dim, 1)

weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
requires_grad=False)

layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})

weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
requires_grad=False)

layer.register_parameter("weight", weight)
set_weight_attrs(
weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
"logical_widths": output_partition_sizes
})

# Don't need a shard_splitter for channel-wise quantization
# Use the default loading method
if self.strategy == QuantizationStrategy.CHANNEL:
set_weight_attrs(weight_scale, {
"output_dim": 0,
})
else:
set_weight_attrs(
weight_scale, {
"logical_widths": output_partition_sizes,
"shard_splitter": self.scales_shard_splitter,
})
Original file line number Diff line number Diff line change
@@ -1,97 +1,28 @@
from typing import Callable, List, Tuple, Union
from typing import Callable, List

import torch
from torch.nn import Parameter

from vllm import _custom_ops as custom_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import ( # noqa: E501
CompressedTensorsW8A8)

__all__ = ["CompressedTensorsW8A8DynamicToken"]


class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):

def __init__(self, strategy: str):
self.strategy = strategy

def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id

assert isinstance(shard_id, str)
qkv_idxs = {"q": 0, "k": 1, "v": 2}
assert shard_id in qkv_idxs
return qkv_idxs[shard_id]

def scales_shard_splitter(
self, param: torch.Tensor, loaded_weight: torch.Tensor,
shard_id: Union[str, int],
logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shard_id = self._shard_id_as_int(shard_id)
offset = sum(logical_widths[:shard_id])
size = logical_widths[shard_id]
# update loaded weight with copies for broadcast.
loaded_weight = loaded_weight.repeat(size)
return param[offset:offset + size], loaded_weight
class CompressedTensorsW8A8DynamicToken(CompressedTensorsW8A8):

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

# When the scales have a single value, it is required that they be
# on the CPU for performance and CUDA Graphs compatibility. Please
# refer to the comment in
# CompressedTensorsW8A8StaticTensor::create_weights for further
# information.
is_tensor_partitioned = len(output_partition_sizes) != 1
# when doing channel-wise quantization, number of scales
# is equal to output_dim
weight_scale_dim = sum(output_partition_sizes) if (
is_tensor_partitioned
or self.strategy == QuantizationStrategy.CHANNEL) else 1

shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
if self.strategy == QuantizationStrategy.CHANNEL:
shape = (weight_scale_dim, 1)

weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
requires_grad=False)

weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
requires_grad=False)

layer.register_parameter("weight", weight)
set_weight_attrs(
weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
"logical_widths": output_partition_sizes
})

layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})

# Don't need a shard_splitter for channel-wise quantization
# Use the default loading method
if self.strategy == QuantizationStrategy.CHANNEL:
set_weight_attrs(weight_scale, {
"output_dim": 0,
})
else:
set_weight_attrs(
weight_scale, {
"logical_widths": output_partition_sizes,
"shard_splitter": self.scales_shard_splitter,
})
super().create_weights(
layer=layer,
output_partition_sizes=output_partition_sizes,
input_size_per_partition=input_size_per_partition,
params_dtype=params_dtype,
weight_loader=weight_loader)

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight
Expand Down
Original file line number Diff line number Diff line change
@@ -1,79 +1,39 @@
from typing import Callable, List, Tuple, Union
from typing import Callable, List

import torch
from torch.nn import Parameter

from vllm import _custom_ops as custom_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import ( # noqa: E501
CompressedTensorsW8A8)
from vllm.model_executor.utils import set_weight_attrs

__all__ = ["CompressedTensorsW8A8StaticTensor"]


class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):

def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id

assert isinstance(shard_id, str)
qkv_idxs = {"q": 0, "k": 1, "v": 2}
assert shard_id in qkv_idxs
return qkv_idxs[shard_id]

def scales_shard_splitter(
self, param: torch.Tensor, loaded_weight: torch.Tensor,
shard_id: Union[str, int],
logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shard_id = self._shard_id_as_int(shard_id)
offset = sum(logical_widths[:shard_id])
size = logical_widths[shard_id]
# update loaded weight with copies for broadcast.
loaded_weight = loaded_weight.repeat(size)
return param[offset:offset + size], loaded_weight
class CompressedTensorsW8A8StaticTensor(CompressedTensorsW8A8):

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

is_tensor_partitioned = len(output_partition_sizes) != 1
weight_scale_dim = sum(
output_partition_sizes) if is_tensor_partitioned else 1
super().create_weights(
layer=layer,
output_partition_sizes=output_partition_sizes,
input_size_per_partition=input_size_per_partition,
params_dtype=params_dtype,
weight_loader=weight_loader)

input_scale = Parameter(torch.empty(1, dtype=torch.float32),
requires_grad=False)

weight_scale = Parameter(torch.empty(weight_scale_dim,
dtype=torch.float32),
requires_grad=False)

weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
requires_grad=False)

layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"weight_loader": weight_loader,
"input_dim": 1,
"output_dim": 0,
})
layer.register_parameter("input_scale", input_scale)
set_weight_attrs(input_scale, {
"weight_loader": weight_loader,
"ignore_warning": True,
})
layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(
weight_scale, {
"weight_loader": weight_loader,
"shard_splitter": self.scales_shard_splitter,
"logical_widths": output_partition_sizes,
"ignore_warning": True,
})

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight
Expand Down
Loading