
Commit 6868f97: format
dsikka committed Apr 24, 2024 · 1 parent b840eae
Showing 4 changed files with 53 additions and 49 deletions.
@@ -74,7 +74,7 @@ def _get_schema(self, weight_quant: Dict, input_quant: Dict):

         is_8_bits = weight_bit == input_bit == 8
         is_tensor = weight_strategy == input_strategy == "tensor"
-        is_symmetric = weight_symmetric == input_symmetric
+        is_symmetric = weight_symmetric and input_symmetric

         if is_8_bits and is_tensor and is_symmetric:
             return CompressedTensorsW8A8StaticTensor(
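Note: despite the "format" commit message, this hunk reads as a behavioral fix. With `==`, two asymmetric quantizers (both flags False) compare equal, so the old check could wrongly route them to the symmetric-only kernel; `and` requires both to actually be symmetric. A minimal illustration:

    # Both quantizers asymmetric: equality still holds, conjunction does not.
    weight_symmetric, input_symmetric = False, False

    assert (weight_symmetric == input_symmetric) is True    # old check passes (wrong)
    assert (weight_symmetric and input_symmetric) is False  # new check fails (correct)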
@@ -1,22 +1,23 @@
 from abc import ABC, abstractmethod
 import torch
 from typing import Dict

 __all__ = ["CompressedTensorsScheme"]


 class CompressedTensorsScheme(ABC):
     """
     Abstract class used to describe the weight creation and forward pass of different
     quantization schemes supported by CompressedTensors.
     """

     @abstractmethod
     def create_weights(self, *args, **kwargs):
         """
         Weight creation for the particular scheme. Inputs to this function
         """
         raise NotImplementedError

     @abstractmethod
     def apply_weights(self, weights: Dict, x: torch.Tensor):

@@ -29,4 +30,4 @@ def apply_weights(self, weights: Dict, x: torch.Tensor):
         :param x: input to the layer
         """
         raise NotImplementedError
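For context, a sketch of how this interface is consumed (a hypothetical wrapper, not part of this commit; the real call sites are vLLM's quantized linear layers, and `CompressedTensorsScheme` is assumed imported from the module above):

    import torch


    class SchemeLinear(torch.nn.Module):
        """Hypothetical wrapper: the scheme owns weight layout and the matmul."""

        def __init__(self, scheme: CompressedTensorsScheme, **weight_kwargs):
            super().__init__()
            self.scheme = scheme
            # create_weights returns a dict of tensors keyed by name; the
            # exact kwargs depend on the concrete scheme.
            self.weights = scheme.create_weights(**weight_kwargs)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.scheme.apply_weights(self.weights, x)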
@@ -1,23 +1,24 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme
-)
+    CompressedTensorsScheme)
 import torch
 from typing import Dict, List
 from torch.nn import Parameter
 from vllm.model_executor.utils import set_weight_attrs
 import torch.nn.functional as F

 __all__ = ["CompressedTensorsUnquantized"]


 class CompressedTensorsUnquantized(CompressedTensorsScheme):
     """
     Implements the scheme for all layers which are ignored in the CompressedTensors
     config. The input and loaded weight are used in a linear transformation.
     """

     def create_weights(self, output_sizes_per_partition: List[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, **kwargs):

         weight = Parameter(torch.empty(sum(output_sizes_per_partition),
                                        input_size_per_partition,
                                        device="cuda",

@@ -26,7 +27,7 @@ def create_weights(self, output_sizes_per_partition: List[int],

         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
         return {"weight": weight}

     def apply_weights(self, weights: Dict, x: torch.Tensor):
         weight = weights.get("weight")
         return F.linear(x, weight)
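A quick smoke test of the scheme above (requires a CUDA device, since create_weights allocates on "cuda"; in real use the empty weight is filled by the checkpoint loader, so the random fill here is a stand-in):

    import torch

    scheme = CompressedTensorsUnquantized()
    weights = scheme.create_weights(output_sizes_per_partition=[256, 256],
                                    input_size_per_partition=128,
                                    params_dtype=torch.float16)
    weights["weight"].data.normal_()  # stand-in for checkpoint loading

    x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    y = scheme.apply_weights(weights, x)
    assert y.shape == (4, 512)  # the two output partitions are fused on dim 0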
@@ -1,23 +1,26 @@
 import torch
 from typing import Dict, List, Any, Union, Tuple
 from vllm.model_executor.layers.quantization.compressed_tensors.cutlass_gemm import (
-    cutlass_gemm_dq
-)
+    cutlass_gemm_dq)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme
-)
+    CompressedTensorsScheme)
 from vllm.model_executor.utils import set_weight_attrs
 from torch.nn import Parameter
 from vllm._C import ops

 __all__ = ["CompressedTensorsW8A8StaticTensor"]


 class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):

     def __init__(self, fake_quant):
         self.fake_quant = fake_quant

-    def _quantize(self, x: torch.Tensor, scales: torch.Tensor,
-                  logical_widths: List[int], split_dim : int = 0) -> torch.Tensor:
+    def _quantize(self,
+                  x: torch.Tensor,
+                  scales: torch.Tensor,
+                  logical_widths: List[int],
+                  split_dim: int = 0) -> torch.Tensor:

         x_q = torch.empty_like(x, dtype=torch.int8, device="cuda")
         x_q_split = x_q.split(logical_widths, dim=split_dim)
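The remainder of _quantize is collapsed in this diff. As a rough sketch of the per-shard loop it presumably performs (an assumption from the signature and the split above; plain PyTorch stands in for the CUDA op, and `quantize_per_shard` is an invented name):

    import torch
    from typing import List

    def quantize_per_shard(x: torch.Tensor, scales: torch.Tensor,
                           logical_widths: List[int],
                           split_dim: int = 0) -> torch.Tensor:
        x_q = torch.empty_like(x, dtype=torch.int8)
        x_split = x.split(logical_widths, dim=split_dim)
        x_q_split = x_q.split(logical_widths, dim=split_dim)
        for x_part, x_q_part, scale in zip(x_split, x_q_split, scales):
            # Symmetric int8: scale, round, clamp; one scale per logical shard
            # (e.g. q/k/v). The splits are views, so copy_ fills x_q in place.
            x_q_part.copy_(
                torch.clamp(torch.round(x_part / scale), -128, 127).to(torch.int8))
        return x_q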
@@ -38,28 +41,24 @@ def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
             return shard_id

         assert isinstance(shard_id, str)
-        qkv_idxs = { "q": 0, "k": 1, "v": 2 }
+        qkv_idxs = {"q": 0, "k": 1, "v": 2}
         assert shard_id in qkv_idxs
         return qkv_idxs[shard_id]

-    def scales_shard_splitter(self,
-                              param: torch.Tensor,
-                              loaded_weight: torch.Tensor,
-                              shard_id: Union[str, int],
-                              logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def scales_shard_splitter(
+            self, param: torch.Tensor, loaded_weight: torch.Tensor,
+            shard_id: Union[str, int],
+            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         shard_id = self._shard_id_as_int(shard_id)
         offset = sum(logical_widths[:shard_id])
         size = logical_widths[shard_id]
         # update loaded weight with copies for broadcast.
         loaded_weight = loaded_weight.repeat(size)
-        return param[offset : offset + size], loaded_weight
+        return param[offset:offset + size], loaded_weight
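A worked example of the splitter above, with invented numbers: merging the checkpoint's per-tensor scale for the "k" shard into a fused QKV scale parameter.

    import torch

    logical_widths = [4, 2, 2]              # q, k, v output widths
    param = torch.zeros(sum(logical_widths))
    loaded_weight = torch.tensor([0.5])     # checkpoint scale stored for "k"

    shard_id = 1                            # "k" maps to 1 via _shard_id_as_int
    offset = sum(logical_widths[:shard_id])           # 4
    size = logical_widths[shard_id]                   # 2
    param[offset:offset + size] = loaded_weight.repeat(size)
    # param rows 4..5 now hold 0.5; the scalar scale is broadcast across
    # every output channel of the k shard.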


-    def create_weights(self,
-                       output_sizes_per_partition: List[int],
-                       input_size_per_partition: int,
-                       params_dtype: torch.dtype,
-                       **kwargs):
+    def create_weights(self, output_sizes_per_partition: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, **kwargs):

         # TODO: remove zero_point parameters once the configs given remove them

@@ -72,18 +71,18 @@ def create_weights(self,
                                             dtype=torch.float32),
                                 requires_grad=False)
         input_zero_point = Parameter(torch.empty(1,
                                                  device="cuda",
                                                  dtype=torch.int8),
                                      requires_grad=False)

         weight_scale = Parameter(torch.empty(dim,
                                              device="cuda",
                                              dtype=torch.float32),
                                  requires_grad=False)
         weight_zero_point = Parameter(torch.empty(1,
                                                   device="cuda",
                                                   dtype=torch.int8),
                                       requires_grad=False)

         weight = Parameter(torch.empty(sum(output_sizes_per_partition),
                                        input_size_per_partition,

@@ -92,9 +91,11 @@ def create_weights(self,
                            requires_grad=False)

         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        set_weight_attrs(weight_scale,
-                         {"shard_splitter" : self.scales_shard_splitter,
-                          "logical_widths" : output_sizes_per_partition})
+        set_weight_attrs(
+            weight_scale, {
+                "shard_splitter": self.scales_shard_splitter,
+                "logical_widths": output_sizes_per_partition
+            })

         weights["weight"] = weight
         weights["input_scale"] = input_scale

@@ -104,7 +105,6 @@ def create_weights(self,
         weights["logical_widths"] = output_sizes_per_partition
         return weights


     def apply_weights(self, weights: Dict, x: torch.Tensor):
         weight = weights.get("weight")
         weight_scale = weights.get("weight_scale")

@@ -120,10 +120,12 @@ def apply_weights(self, weights: Dict, x: torch.Tensor):
         # TODO : try not to remove device-to-host copy. i.e. keep the non-duplicated version
         # of scales in the CPU
         if self.fake_quant:
-            w_scales = [weight_scale[sum(logical_widths[:i])].item() for i in range(len(logical_widths))]
+            w_scales = [
+                weight_scale[sum(logical_widths[:i])].item()
+                for i in range(len(logical_widths))
+            ]
             w_scales = torch.FloatTensor(w_scales, device=torch.device("cpu"))
             w_q = self._quantize(weight, w_scales, logical_widths)
             # GEMM and dq
             return cutlass_gemm_dq(x_q, w_q, x.dtype, weight_scale, act_scale)
         return cutlass_gemm_dq(x_q, weight, x.dtype, weight_scale, act_scale)
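Both return paths at the end are live: with fake_quant set, the floating-point weight is quantized on the fly before the int8 GEMM; otherwise the stored weight is assumed to already be int8. For intuition, a plausible CPU reference for what the fused kernel computes, assuming symmetric quantization with per-output-channel weight scales and a per-tensor activation scale (the actual CUTLASS epilogue is not shown in this diff, and `gemm_dq_reference` is an invented name):

    import torch

    def gemm_dq_reference(x_q: torch.Tensor, w_q: torch.Tensor,
                          out_dtype: torch.dtype, weight_scale: torch.Tensor,
                          act_scale: torch.Tensor) -> torch.Tensor:
        # int8 x int8 with int32 accumulation, then dequantize with the
        # per-output-channel weight scales and the per-tensor activation scale.
        acc = x_q.to(torch.int32) @ w_q.to(torch.int32).t()
        return (acc.to(torch.float32) * act_scale * weight_scale).to(out_dtype)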
