[Kernel] Initial Activation Quantization Support #4525

Merged May 23, 2024 (49 commits)
Changes from 1 commit
Commits
4d27a2c  Initial `CompressedTensors` config + Activation Quantization support … (dsikka, Apr 30, 2024)
92b3703  add get_quant method to compressed tensors config (dsikka, Apr 30, 2024)
2a3eb83  small rebase fixed (dsikka, Apr 30, 2024)
3dd1fe8  format (dsikka, Apr 30, 2024)
f2f8c52  fix mypy complaints (Apr 30, 2024)
c9308eb  Merge branch 'main' into ds-quant (dsikka, Apr 30, 2024)
d9d49b5  format fixes (dsikka, Apr 30, 2024)
b111ee6  Merge branch 'main' into ds-quant (dsikka, May 1, 2024)
c31a7af  format fix post rebase (dsikka, May 1, 2024)
ca01b39  lazy import CompressedTensorsW8A8StaticTensor (#220) (varun-sundar-rabindranath, May 1, 2024)
f0197d4  lazy cutlass_gemm_dq import (#221) (varun-sundar-rabindranath, May 1, 2024)
4624b46  fix asm (May 1, 2024)
75757d5  update shape change (dsikka, May 2, 2024)
e1df0eb  add todo (dsikka, May 2, 2024)
bc0991c  Rename quant_per_tensor -> static_scaled_int8_quant (May 2, 2024)
74ad650  Remove cruft (May 2, 2024)
43c43f3  Merge branch 'main' into ds-quant (dsikka, May 14, 2024)
cf5600f  fixes : typo (May 14, 2024)
169ce7f  py-cutlass temporary hack for num_prompts==1 (May 15, 2024)
03b53e7  yapf (May 15, 2024)
f9df31b  add test_int8_quant (May 16, 2024)
ba4b6b3  call cpp cutlass (May 17, 2024)
3c223c6  Merge branch 'main' into ds-quant (dsikka, May 17, 2024)
b27f31a  remove cutlass py interface (May 17, 2024)
b589cdd  format.sh (May 17, 2024)
98159cf  remove fake-quant (May 17, 2024)
8dbeb31  add compressed tensors test (dsikka, May 17, 2024)
5eeb40a  remove torch.int8 (dsikka, May 17, 2024)
c55e023  format (dsikka, May 17, 2024)
f5cbbd3  fix config parsing to match new model (dsikka, May 20, 2024)
a685957  revert parsing to use default pathway (dsikka, May 20, 2024)
4dfb37f  PR comments (dsikka, May 21, 2024)
de81f9e  Fix scales/zero-points device allocation (May 21, 2024)
15f1863  ruff (May 21, 2024)
bd53847  add better comments (May 21, 2024)
b2926f3  add comment (dsikka, May 22, 2024)
1274386  Merge branch 'main' into ds-quant (dsikka, May 22, 2024)
18640c8  clang format (dsikka, May 22, 2024)
5c5dc84  clang format again (dsikka, May 22, 2024)
a44b4a0  address PR comments (May 22, 2024)
6f0e6e1  clang-format (May 22, 2024)
0090454  remove layer name (dsikka, May 23, 2024)
4b10fd7  remove unused import (dsikka, May 23, 2024)
68a59c7  remove parent name (dsikka, May 23, 2024)
b0afe67  Fix rounding (May 22, 2024)
4f4951e  comment (May 23, 2024)
869de3f  cruft (May 23, 2024)
e68e391  yapf (May 23, 2024)
d77cf50  remove unquantized check (dsikka, May 23, 2024)
format fixes
dsikka committed Apr 30, 2024
commit d9d49b5224dccb16eb28628ed9fb5f95b07437cc
10 changes: 5 additions & 5 deletions vllm/model_executor/layers/linear.py
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Any, Dict, List, Optional
from typing import List, Optional

import torch
import torch.nn.functional as F
@@ -38,7 +38,7 @@ def create_weights(self,
output_size: int,
params_dtype: torch.dtype,
layer_name: Optional[str] = None,
**extra_weight_attrs) -> Dict[str, Any]:
**extra_weight_attrs):
"""Create weights for a linear layer.
The weights will be set as attributes of the layer.

@@ -84,7 +84,7 @@ def create_weights(self,
output_size: int,
params_dtype: torch.dtype,
layer_name: Optional[str] = None,
**extra_weight_attrs) -> Dict[str, Any]:
**extra_weight_attrs):
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
@@ -413,7 +413,7 @@ def weight_loader(self,
param_data = param_data.narrow(0, shard_offset, shard_size)
# If a param_shard_splitter is defined by the LinearMethod, use it.
Collaborator comment:
This does the same thing as scale_shard_splitter we had for fp8 ... we can rename to match fp8

but yes this will be addressed by the refactor

elif param_shard_splitter is not None:
logical_widths = getattr(param, "logical_widths")
logical_widths = getattr(param, "logical_widths", None)
param_data, loaded_weight = param_shard_splitter(
param_data, loaded_weight, loaded_shard_id, logical_widths)

@@ -601,7 +601,7 @@ def weight_loader(self,
shard_size)
# If a param_shard_splitter is defined by the LinearMethod, use it.
elif param_shard_splitter is not None:
logical_widths = getattr(param, "logical_widths")
logical_widths = getattr(param, "logical_widths", None)
param_data, loaded_weight = param_shard_splitter(
param_data, loaded_weight, loaded_shard_id, logical_widths)

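Aside (not part of this diff): the hunks above only show the call sites of the param_shard_splitter hook, so the following is an illustrative sketch, inferred from those call sites rather than taken from this PR, of what such a callback could look like. It narrows a merged parameter to the slice owned by one shard using the per-shard logical widths; the shard id is assumed here to be an integer index.

import torch
from typing import List, Tuple


def example_param_shard_splitter(
        param_data: torch.Tensor, loaded_weight: torch.Tensor,
        shard_id: int,
        logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    # Offset of this shard is the sum of the widths of all preceding shards.
    offset = sum(logical_widths[:shard_id])
    size = logical_widths[shard_id]
    # Narrow the destination parameter to the region owned by this shard;
    # the loaded weight is returned unchanged in this simplified sketch.
    return param_data.narrow(0, offset, size), loaded_weight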
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/__init__.py
@@ -4,9 +4,9 @@
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -5,11 +5,9 @@
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig)

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8StaticTensor, CompressedTensorsUnquantized,
CompressedTensorsScheme)
from vllm.model_executor.utils import set_weight_attrs
CompressedTensorsScheme, CompressedTensorsUnquantized,
CompressedTensorsW8A8StaticTensor)


class CompressedTensorsConfig(QuantizationConfig):
@@ -138,8 +136,8 @@ def create_weights(self,
layer_name: Optional[str] = None,
**extra_weight_attrs):
"""
Use the CompressedTensorsScheme associated with each layer to create the
necessary parameters for the layer.
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer.
"""
weight_loader = extra_weight_attrs.get("weight_loader")

@@ -160,8 +158,9 @@ def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None):
"""
Use the output of create_weights and the CompressedTensorsScheme associated with
the layer to apply the forward pass with the layer input.
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
layer input.
"""

if bias is not None:
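Aside (not part of this diff): the two docstrings above describe the delegation pattern this config uses: each layer is paired with a CompressedTensorsScheme that registers the layer's parameters in create_weights and later runs the forward pass in apply. A condensed, hypothetical sketch of that flow (names invented for illustration, bias handling omitted):

import torch


def create_weights_with_scheme(layer: torch.nn.Module, scheme,
                               input_size: int, output_size: int,
                               params_dtype: torch.dtype) -> None:
    # Let the per-layer scheme register whatever parameters it needs,
    # then remember the scheme so the forward pass can find it later.
    scheme.create_weights(layer=layer, input_size=input_size,
                          output_size=output_size, params_dtype=params_dtype)
    layer.scheme = scheme


def apply_with_scheme(layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # The forward pass is delegated to the same scheme that created the weights.
    return layer.scheme.apply_weights(layer, x)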
vllm/model_executor/layers/quantization/compressed_tensors/cutlass_gemm.py
@@ -1,9 +1,9 @@
from typing import Any, Dict, Optional, Tuple, Union

import cutlass
from cutlass import Tensor as FakeTensor
import cutlass.epilogue

import torch
from typing import Optional, Tuple, Dict, Union, Any
from cutlass import Tensor as FakeTensor

from vllm.logger import init_logger

vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -1,3 +1,5 @@
from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_unquantized import CompressedTensorsUnquantized
from .compressed_tensors_w8a8_statictensor import CompressedTensorsW8A8StaticTensor
from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401
from .compressed_tensors_unquantized import ( # noqa: F401
CompressedTensorsUnquantized)
from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501
CompressedTensorsW8A8StaticTensor)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
@@ -1,13 +1,14 @@
from abc import ABC, abstractmethod

import torch

__all__ = ["CompressedTensorsScheme"]


class CompressedTensorsScheme(ABC):
"""
Abstract class used to describe the weight creation and forward pass of different
quantization schemes supported by CompressedTensors.
Abstract class used to describe the weight creation and forward pass
of different quantization schemes supported by CompressedTensors.
"""

@abstractmethod
@@ -21,11 +22,11 @@ def create_weights(self, *args, **kwargs):
@abstractmethod
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
"""
Run the forward pass for the particular scheme. This is where scheme-specific
dequant/quant steps/kernels should be applied.
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.

:param layer: toch.nn.Module with the registered weights and other parameters
relevant to the particular scheme.
:param layer: toch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
:param x: input to the layer

"""
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
@@ -1,18 +1,21 @@
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from typing import Callable, List

import torch
from typing import List, Callable
import torch.nn.functional as F
from torch.nn import Parameter

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.utils import set_weight_attrs
import torch.nn.functional as F

__all__ = ["CompressedTensorsUnquantized"]


class CompressedTensorsUnquantized(CompressedTensorsScheme):
"""
Implements the scheme for all layers which are ignored in the CompressedTensors
config. The input and loaded weight are used in a linear transformation.
Implements the scheme for all layers which are ignored
in the CompressedTensors config. The input and loaded weight are used
in a linear transformation.
"""

def create_weights(self, layer: torch.nn.Module,
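Aside (not part of this diff): the CompressedTensorsUnquantized docstring above says the scheme simply applies a linear transformation to the loaded weight. Purely as an illustration of the CompressedTensorsScheme interface from the previous file, a stripped-down, hypothetical pass-through scheme could look roughly like this (the real create_weights in this PR also wires up weight loading via set_weight_attrs, which is omitted here):

import torch
import torch.nn.functional as F

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)


class ExamplePassthroughScheme(CompressedTensorsScheme):
    """Hypothetical scheme: unquantized weight, plain linear forward."""

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_size: int, params_dtype: torch.dtype, **kwargs):
        # Register an (initially empty) weight; actual values are loaded later.
        weight = torch.nn.Parameter(
            torch.empty(output_size, input_size, dtype=params_dtype))
        layer.register_parameter("weight", weight)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
        # Plain, unquantized linear transformation on the registered weight.
        return F.linear(x, layer.weight)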
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
@@ -1,12 +1,14 @@
from typing import Callable, List, Tuple, Union

import torch
from typing import List, Union, Tuple, Callable
from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import (
from torch.nn import Parameter

from vllm._C import ops
from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import ( # noqa: E501
cutlass_gemm_dq)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.utils import set_weight_attrs
from torch.nn import Parameter
from vllm._C import ops

__all__ = ["CompressedTensorsW8A8StaticTensor"]

@@ -94,7 +96,7 @@ def create_weights(self, layer: torch.nn.Module,

layer.register_parameter("weight", weight)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
# Register parameter with the layer; register weight loader with each parameter

set_weight_attrs(weight, {"weight_loader": weight_loader})
set_weight_attrs(weight, {"logical_widths": output_partition_sizes})

@@ -122,8 +124,8 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
x_q = self._quantize_single(x, act_scale[0].item())

# Weight quantize
# TODO : try not to remove device-to-host copy. i.e. keep the non-duplicated version
# of scales in the CPU
# TODO : try not to remove device-to-host copy.
# i.e. keep the non-duplicated version of scales in the CPU
if self.fake_quant:
w_scales = [
weight_scale[sum(logical_widths[:i])].item()
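Aside (not part of this diff): in the apply_weights path above, the activation is quantized with a single precomputed (static) per-tensor scale before the CUTLASS int8 GEMM; the actual kernel is exposed through vllm._C.ops and was renamed to static_scaled_int8_quant in commit bc0991c. Purely as a reference for what static per-tensor int8 quantization computes, a plain-PyTorch equivalent might look like the following; the kernel's exact rounding behavior may differ (the later "Fix rounding" commit touches exactly that).

import torch


def static_scaled_int8_quant_ref(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Scale by the precomputed activation scale, round to the nearest integer,
    # then clamp to the signed int8 range before casting.
    return torch.round(x / scale).clamp(-128, 127).to(torch.int8)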
1 change: 1 addition & 0 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -206,6 +206,7 @@ def create_weights(
input_size: int,
output_size: int,
params_dtype: torch.dtype,
layer_name: Optional[str] = None,
**extra_weight_attrs,
) -> None:
del output_size