Skip to content

Commit

Permalink
[Misc] Add per channel support for static activation quantization; up…
Browse files Browse the repository at this point in the history
…date w8a8 schemes to share base classes (vllm-project#5650)
  • Loading branch information
dsikka authored and jimpang committed Jul 8, 2024
1 parent 00981a3 commit 6155517
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 136 deletions.
14 changes: 10 additions & 4 deletions tests/quantization/test_compressed_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
CompressedTensorsW8A8StaticTensor)


def test_compressed_tensors_w8a8_static_setup(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
@pytest.mark.parametrize("model_args", [
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor"),
("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel"),
])
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
model_path, strategy = model_args
with vllm_runner(model_path, enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
Expand All @@ -33,12 +37,14 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):

assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)

assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.weight.dtype is torch.int8
assert o_proj.weight.dtype is torch.int8
assert gate_up_proj.weight.dtype is torch.int8

assert qkv_proj.weight_scale.shard_splitter is not None
assert qkv_proj.weight_scale.logical_widths is not None
if qkv_proj.scheme.strategy == "tensor":
assert qkv_proj.weight_scale.shard_splitter is not None
assert qkv_proj.weight_scale.logical_widths is not None
assert qkv_proj.input_scale.dtype is torch.float32


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,11 @@ def get_config_filenames(cls) -> List[str]:
def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
is_tensor = (weight_quant.strategy == input_quant.strategy ==
QuantizationStrategy.TENSOR.value)
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.TENSOR.value
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_tensor = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TENSOR.value)
is_symmetric = weight_quant.symmetric and input_quant.symmetric
is_static = not weight_quant.dynamic and not input_quant.dynamic

Expand Down Expand Up @@ -131,7 +134,8 @@ def _get_schema(self, weight_quant: BaseModel,

if self.quant_format == CompressionFormat.int_quantized.value:
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8StaticTensor()
return CompressedTensorsW8A8StaticTensor(
strategy=weight_quant.strategy)

if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8DynamicToken(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import Callable, List, Tuple, Union

import torch
from torch.nn import Parameter

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.utils import set_weight_attrs


class CompressedTensorsW8A8(CompressedTensorsScheme):
    """Shared base for the 8-bit weight / 8-bit activation schemes.

    Holds the weight quantization strategy and registers the int8
    ``weight`` and float32 ``weight_scale`` parameters on a layer.
    """

    def __init__(self, strategy: str):
        # Compared against QuantizationStrategy.CHANNEL in create_weights.
        self.strategy = strategy

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        """Map a shard id to its integer position.

        Integers are returned unchanged; the strings "q", "k" and "v"
        map to 0, 1 and 2. Anything else trips an assertion.
        """
        if not isinstance(shard_id, int):
            assert isinstance(shard_id, str)
            position_by_name = {"q": 0, "k": 1, "v": 2}
            assert shard_id in position_by_name
            return position_by_name[shard_id]
        return shard_id

    def scales_shard_splitter(
            self, param: torch.Tensor, loaded_weight: torch.Tensor,
            shard_id: Union[str, int],
            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pick the slice of ``param`` that belongs to one logical shard.

        Returns a tuple of (slice of ``param`` for the shard,
        ``loaded_weight`` repeated once per element of that slice).
        """
        shard_idx = self._shard_id_as_int(shard_id)
        start = sum(logical_widths[:shard_idx])
        width = logical_widths[shard_idx]
        # Replicate the loaded scale so it can be copied across the
        # whole slice.
        replicated = loaded_weight.repeat(width)
        stop = start + width
        return param[start:stop], replicated

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Register ``weight_scale`` and ``weight`` on ``layer``.

        ``weight_scale`` is float32; ``weight`` is int8 with shape
        (sum(output_partition_sizes), input_size_per_partition).
        ``params_dtype`` and ``**kwargs`` are accepted but not used here.
        """
        is_channelwise = self.strategy == QuantizationStrategy.CHANNEL
        total_output_size = sum(output_partition_sizes)

        # One scale per output row when the layer fuses several logical
        # partitions or when quantization is channel-wise; otherwise a
        # single scale.
        if len(output_partition_sizes) != 1 or is_channelwise:
            num_scales = total_output_size
        else:
            num_scales = 1

        scale_shape: Union[Tuple[int], Tuple[int, int]]
        if is_channelwise:
            scale_shape = (num_scales, 1)
        else:
            scale_shape = (num_scales, )

        weight_scale = Parameter(
            torch.empty(*scale_shape, dtype=torch.float32),
            requires_grad=False)
        layer.register_parameter("weight_scale", weight_scale)
        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})

        weight = Parameter(
            torch.empty(total_output_size,
                        input_size_per_partition,
                        dtype=torch.int8),
            requires_grad=False)
        layer.register_parameter("weight", weight)
        weight_attrs = {
            "input_dim": 1,
            "output_dim": 0,
            "weight_loader": weight_loader,
            "logical_widths": output_partition_sizes
        }
        set_weight_attrs(weight, weight_attrs)

        # Channel-wise scales need no shard_splitter: tagging output_dim
        # lets the default loading method handle them.
        if is_channelwise:
            extra_scale_attrs = {"output_dim": 0}
        else:
            extra_scale_attrs = {
                "logical_widths": output_partition_sizes,
                "shard_splitter": self.scales_shard_splitter,
            }
        set_weight_attrs(weight_scale, extra_scale_attrs)
Original file line number Diff line number Diff line change
@@ -1,97 +1,28 @@
from typing import Callable, List, Tuple, Union
from typing import Callable, List

import torch
from torch.nn import Parameter

from vllm import _custom_ops as custom_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import ( # noqa: E501
CompressedTensorsW8A8)

__all__ = ["CompressedTensorsW8A8DynamicToken"]


class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):

def __init__(self, strategy: str):
self.strategy = strategy

def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id

assert isinstance(shard_id, str)
qkv_idxs = {"q": 0, "k": 1, "v": 2}
assert shard_id in qkv_idxs
return qkv_idxs[shard_id]

def scales_shard_splitter(
self, param: torch.Tensor, loaded_weight: torch.Tensor,
shard_id: Union[str, int],
logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shard_id = self._shard_id_as_int(shard_id)
offset = sum(logical_widths[:shard_id])
size = logical_widths[shard_id]
# update loaded weight with copies for broadcast.
loaded_weight = loaded_weight.repeat(size)
return param[offset:offset + size], loaded_weight
class CompressedTensorsW8A8DynamicToken(CompressedTensorsW8A8):

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

# When the scales have a single value, it is required that they be
# on the CPU for performance and CUDA Graphs compatibility. Please
# refer to the comment in
# CompressedTensorsW8A8StaticTensor::create_weights for further
# information.
is_tensor_partitioned = len(output_partition_sizes) != 1
# when doing channel-wise quantization, number of scales
# is equal to output_dim
weight_scale_dim = sum(output_partition_sizes) if (
is_tensor_partitioned
or self.strategy == QuantizationStrategy.CHANNEL) else 1

shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
if self.strategy == QuantizationStrategy.CHANNEL:
shape = (weight_scale_dim, 1)

weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
requires_grad=False)

weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
requires_grad=False)

layer.register_parameter("weight", weight)
set_weight_attrs(
weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
"logical_widths": output_partition_sizes
})

layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})

# Don't need a shard_splitter for channel-wise quantization
# Use the default loading method
if self.strategy == QuantizationStrategy.CHANNEL:
set_weight_attrs(weight_scale, {
"output_dim": 0,
})
else:
set_weight_attrs(
weight_scale, {
"logical_widths": output_partition_sizes,
"shard_splitter": self.scales_shard_splitter,
})
super().create_weights(
layer=layer,
output_partition_sizes=output_partition_sizes,
input_size_per_partition=input_size_per_partition,
params_dtype=params_dtype,
weight_loader=weight_loader)

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight
Expand Down
Original file line number Diff line number Diff line change
@@ -1,79 +1,39 @@
from typing import Callable, List, Tuple, Union
from typing import Callable, List

import torch
from torch.nn import Parameter

from vllm import _custom_ops as custom_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import ( # noqa: E501
CompressedTensorsW8A8)
from vllm.model_executor.utils import set_weight_attrs

__all__ = ["CompressedTensorsW8A8StaticTensor"]


class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):

def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id

assert isinstance(shard_id, str)
qkv_idxs = {"q": 0, "k": 1, "v": 2}
assert shard_id in qkv_idxs
return qkv_idxs[shard_id]

def scales_shard_splitter(
self, param: torch.Tensor, loaded_weight: torch.Tensor,
shard_id: Union[str, int],
logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shard_id = self._shard_id_as_int(shard_id)
offset = sum(logical_widths[:shard_id])
size = logical_widths[shard_id]
# update loaded weight with copies for broadcast.
loaded_weight = loaded_weight.repeat(size)
return param[offset:offset + size], loaded_weight
class CompressedTensorsW8A8StaticTensor(CompressedTensorsW8A8):

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

is_tensor_partitioned = len(output_partition_sizes) != 1
weight_scale_dim = sum(
output_partition_sizes) if is_tensor_partitioned else 1
super().create_weights(
layer=layer,
output_partition_sizes=output_partition_sizes,
input_size_per_partition=input_size_per_partition,
params_dtype=params_dtype,
weight_loader=weight_loader)

input_scale = Parameter(torch.empty(1, dtype=torch.float32),
requires_grad=False)

weight_scale = Parameter(torch.empty(weight_scale_dim,
dtype=torch.float32),
requires_grad=False)

weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
requires_grad=False)

layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"weight_loader": weight_loader,
"input_dim": 1,
"output_dim": 0,
})
layer.register_parameter("input_scale", input_scale)
set_weight_attrs(input_scale, {
"weight_loader": weight_loader,
"ignore_warning": True,
})
layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(
weight_scale, {
"weight_loader": weight_loader,
"shard_splitter": self.scales_shard_splitter,
"logical_widths": output_partition_sizes,
"ignore_warning": True,
})

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight
Expand Down

0 comments on commit 6155517

Please sign in to comment.