Commit b620aad
Added asymmetric integration to linear layers
ProExpertProg committed Sep 18, 2024
1 parent d9cd78e commit b620aad
Showing 3 changed files with 47 additions and 8 deletions.
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -138,7 +138,7 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
             or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
         is_tensor = (weight_strategy and input_quant.strategy
                      == QuantizationStrategy.TENSOR.value)
-        is_symmetric = weight_quant.symmetric and input_quant.symmetric
+        is_symmetric = weight_quant.symmetric
         is_static = not weight_quant.dynamic and not input_quant.dynamic

         return is_8_bits and is_tensor and is_symmetric and is_static
@@ -151,7 +151,7 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
             or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
         is_token = (weight_strategy and input_quant.strategy
                     == QuantizationStrategy.TOKEN.value)
-        is_symmetric = weight_quant.symmetric and input_quant.symmetric
+        is_symmetric = weight_quant.symmetric
         is_dynamic = not weight_quant.dynamic and input_quant.dynamic

         return is_8_bits and is_token and is_symmetric and is_dynamic
@@ -265,12 +265,14 @@ def _get_scheme_from_parts(
         if self._is_static_tensor_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Int8(
                 strategy=weight_quant.strategy,
-                is_static_input_scheme=True)
+                is_static_input_scheme=True,
+                input_symmetric=input_quant.symmetric)

         if self._is_dynamic_token_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Int8(
                 strategy=weight_quant.strategy,
-                is_static_input_scheme=False)
+                is_static_input_scheme=False,
+                input_symmetric=input_quant.symmetric)

         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")
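With this change, the W8A8 predicates require symmetry only on the weight side; a config with asymmetric activations is now accepted, and its `symmetric` flag is threaded through to the scheme as `input_symmetric`. A minimal sketch of the relaxed check, using stand-in config objects rather than the real compressed-tensors `BaseModel` classes:

```python
from types import SimpleNamespace

# Stand-ins for the parsed weight/input quantization args; field names
# mirror the ones used in the diff above.
weight_quant = SimpleNamespace(num_bits=8, strategy="channel",
                               symmetric=True, dynamic=False)
input_quant = SimpleNamespace(num_bits=8, strategy="token",
                              symmetric=False, dynamic=True)

def is_dynamic_token_w8a8(weight_quant, input_quant) -> bool:
    # Mirrors _is_dynamic_token_w8a8 after this commit: only the weight
    # side must be symmetric; the input may carry a zero point.
    is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
    weight_strategy = weight_quant.strategy in ("tensor", "channel")
    is_token = weight_strategy and input_quant.strategy == "token"
    is_symmetric = weight_quant.symmetric
    is_dynamic = not weight_quant.dynamic and input_quant.dynamic
    return is_8_bits and is_token and is_symmetric and is_dynamic

# Before this commit, symmetric=False on the input would have failed the
# check; now it resolves to CompressedTensorsW8A8Int8(input_symmetric=False).
assert is_dynamic_token_w8a8(weight_quant, input_quant)
```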
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
@@ -17,9 +17,11 @@

 class CompressedTensorsW8A8Int8(CompressedTensorsScheme):

-    def __init__(self, strategy: str, is_static_input_scheme: bool):
+    def __init__(self, strategy: str, is_static_input_scheme: bool,
+                 input_symmetric: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
+        self.input_symmetric = input_symmetric

     @classmethod
     def get_min_capability(cls) -> int:
@@ -48,8 +50,21 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if self.is_static_input_scheme:
             layer.input_scale = Parameter(layer.input_scale.max(),
                                           requires_grad=False)
+            if not self.input_symmetric:
+                layer.input_zero_point = Parameter(layer.input_zero_point,
+                                                   requires_grad=False)
+            else:
+                layer.input_zero_point = None
         else:
             layer.input_scale = None
+            layer.input_zero_point = None
+
+        if not self.input_symmetric:
+            layer.azp_adj = layer.weight.sum(dim=0,
+                                             keepdim=True,
+                                             dtype=torch.int32)
+        else:
+            layer.azp_adj = None

     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
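Why `process_weights_after_loading` caches the column sums: with asymmetric activation quantization, `x ≈ s_x * (x_q - zp)`, so `x @ W ≈ s_x * (x_q @ W - zp * W.sum(dim=0))`. The zero-point term only needs the per-output-column sums of the weight, which are constant, so they are precomputed once as `azp_adj` instead of being recomputed every forward pass. A small self-check of that identity (assuming the weight is laid out `[in_features, out_features]`, as the `dim=0` reduction suggests):

```python
import torch

def int_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Tiny reference integer matmul (CPU torch lacks an int matmul kernel).
    return (a.unsqueeze(-1) * b.unsqueeze(0)).sum(dim=1)

torch.manual_seed(0)
x_q = torch.randint(-128, 128, (4, 16), dtype=torch.int32)  # quantized activations
w_q = torch.randint(-128, 128, (16, 8), dtype=torch.int32)  # quantized weight [in, out]
zp = 7                                                      # activation zero point

azp_adj = w_q.sum(dim=0, keepdim=True)     # what the layer caches
direct = int_mm(x_q - zp, w_q)             # zero-point-corrected GEMM
via_adj = int_mm(x_q, w_q) - zp * azp_adj  # plain int GEMM + cheap correction
assert torch.equal(direct, via_adj)
```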
@@ -90,11 +105,18 @@ def create_weights(self, layer: torch.nn.Module,
                                    weight_loader=weight_loader)
             layer.register_parameter("input_scale", input_scale)

+            if not self.input_symmetric:
+                raise NotImplementedError(
+                    "static input asymmetric quantization not supported yet")
+                input_zero_point = Parameter(torch.zeros(1, dtype=torch.int8))
+                layer.register_parameter("input_zero_point", input_zero_point)

     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:

         return apply_int8_linear(input=x,
                                  weight=layer.weight,
                                  weight_scale=layer.weight_scale,
                                  input_scale=layer.input_scale,
+                                 input_zero_point=layer.input_zero_point,
+                                 azp_adj=layer.azp_adj,
                                  bias=bias)
19 changes: 17 additions & 2 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -191,13 +191,28 @@ def apply_int8_linear(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
+    input_zero_point: Optional[torch.Tensor] = None,
+    azp_adj: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
 ):
     # ops.scaled_int8_quant supports both dynamic and static quant.
     # * dynamic, layer.input_scale is None and x_scale computed from x.
     # * static, layer.input_scale is scalar and x_scale is input_scale.
-    x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale)
+    symmetric = azp_adj is None
+    x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
+                                               input_scale,
+                                               input_zero_point,
+                                               symmetric=symmetric)
+
+    if x_zp is not None:
+        return ops.cutlass_scaled_mm_azp(x_q,
+                                         weight,
+                                         scale_a=x_scale,
+                                         scale_b=weight_scale,
+                                         out_dtype=input.dtype,
+                                         azp_adj=azp_adj,
+                                         azp=x_zp,
+                                         bias=bias)
     return ops.cutlass_scaled_mm(x_q,
                                  weight,
                                  scale_a=x_scale,
                                  scale_b=weight_scale,
                                  out_dtype=input.dtype,
                                  bias=bias)
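The dispatch above infers `symmetric` from the presence of `azp_adj`: when the layer precomputed an adjustment, `ops.scaled_int8_quant` is asked to return a zero point as well, and the matmul routes through `ops.cutlass_scaled_mm_azp`. Below is a plain-PyTorch sketch of the math that call site implies, reconstructed from the arguments rather than from the CUTLASS kernel itself. (Note that static asymmetric input quantization still raises `NotImplementedError` in `create_weights` above, so only the dynamic per-token path exercises this end to end.)

```python
from typing import Optional

import torch

def cutlass_scaled_mm_azp_ref(x_q: torch.Tensor, w_q: torch.Tensor,
                              scale_a: torch.Tensor, scale_b: torch.Tensor,
                              out_dtype: torch.dtype, azp_adj: torch.Tensor,
                              azp: torch.Tensor,
                              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Accumulate in int32, subtract the zero-point term, then apply scales:
    #   out = scale_a * scale_b * (x_q @ w_q - azp * azp_adj) + bias
    # azp is per-tensor (scalar) or per-token ([M, 1]); azp_adj holds the
    # [1, N] column sums cached by the layer.
    acc = (x_q.to(torch.int32).unsqueeze(-1) *
           w_q.to(torch.int32).unsqueeze(0)).sum(dim=1)  # reference int GEMM
    acc = acc - azp.to(torch.int32) * azp_adj
    out = scale_a * scale_b * acc.to(torch.float32)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)
```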
