[Kernel] AQ AZP 4/4: Integrate asymmetric quantization to linear method (vllm-project#7271)

Signed-off-by: Alvant <alvasian@yandex.ru>
ProExpertProg authored and Alvant committed Oct 26, 2024
1 parent 55ea39f commit c0b18d9
Showing 7 changed files with 124 additions and 21 deletions.
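For context, "asymmetric" quantization here means the int8 activations carry an integer zero point (AZP) in addition to a scale, so a value range that is not centered on zero can still use the full int8 range; this commit wires that AZP support into the compressed-tensors int8 linear method. A minimal sketch of the idea (illustrative helper, not a vLLM API):

import torch

def asym_quantize_int8(x: torch.Tensor):
    # Map the observed range [x.min(), x.max()] onto the full int8 range
    # using a scale plus an integer zero point (AZP).
    qmin, qmax = -128, 127
    scale = (x.max() - x.min()) / (qmax - qmin)
    azp = (qmin - x.min() / scale).round().to(torch.int32)
    x_q = torch.clamp((x / scale).round() + azp, qmin, qmax).to(torch.int8)
    return x_q, scale, azp

x = torch.randn(8, 16) * 3 + 2                  # deliberately not zero-centered
x_q, scale, azp = asym_quantize_int8(x)
x_hat = scale * (x_q.to(torch.float32) - azp)   # dequantize: x_hat ~= x

Symmetric quantization is the special case azp == 0, which is why the symmetric code paths in the diff below can drop the zero point entirely.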
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.764
- name: "exact_match,flexible-extract"
value: 0.764
limit: 250
num_fewshot: 5
1 change: 1 addition & 0 deletions .buildkite/lm-eval-harness/configs/models-small.txt
@@ -1,6 +1,7 @@
Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Minitron-4B-Base-FP8.yaml
Expand Down
7 changes: 6 additions & 1 deletion .buildkite/lm-eval-harness/test_lm_eval_correctness.py
@@ -49,10 +49,15 @@ def test_lm_eval_correctness():
results = launch_lm_eval(eval_config)

# Confirm scores match ground truth.
success = True
for task in eval_config["tasks"]:
for metric in task["metrics"]:
ground_truth = metric["value"]
measured_value = results["results"][task["name"]][metric["name"]]
print(f'{task["name"]} | {metric["name"]}: '
f'ground_truth={ground_truth} | measured={measured_value}')
assert numpy.isclose(ground_truth, measured_value, rtol=RTOL)
success = success and numpy.isclose(
ground_truth, measured_value, rtol=RTOL)

# Assert at the end, print all scores even on failure for debugging.
assert success
36 changes: 27 additions & 9 deletions tests/quantization/test_compressed_tensors.py
@@ -2,6 +2,7 @@
Run `pytest tests/quantization/test_compressed_tensors.py`.
"""
from typing import Optional

import pytest
import torch
@@ -14,14 +15,16 @@
QuantizationType)


@pytest.mark.parametrize("model_args", [
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
QuantizationType.INT, 2560),
("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
QuantizationType.INT, 2560),
])
@pytest.mark.parametrize(
"model_args",
[("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
QuantizationType.INT, 2560, True),
("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
QuantizationType.INT, 2560, True),
("nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", "tensor",
QuantizationType.INT, 2560, False)])
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
model_path, strategy, quant_type, shape_0 = model_args
model_path, strategy, quant_type, shape_0, is_symmetric = model_args
with vllm_runner(model_path, enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
@@ -31,6 +34,18 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
gate_up_proj = layer.mlp.gate_up_proj
down_proj = layer.mlp.down_proj

# assert zp for symmetric and asymmetric cases
def zp_valid(zp: Optional[torch.Tensor]):
if is_symmetric:
return zp is None

return zp is not None and zp.dtype is torch.int32

assert zp_valid(qkv_proj.input_zero_point)
assert zp_valid(o_proj.input_zero_point)
assert zp_valid(gate_up_proj.input_zero_point)
assert zp_valid(down_proj.input_zero_point)

assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(gate_up_proj.quant_method,
@@ -69,9 +84,12 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):

@pytest.mark.parametrize("model_args", [
("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"),
("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym",
"channel"),
])
def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
model_path, strategy = model_args
with vllm_runner(model_path, dtype=torch.float16) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
@@ -160,4 +178,4 @@ def test_compressed_tensors_kv_cache(vllm_runner):
model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
output = llm.generate_greedy("Hello world!", max_tokens=20)
assert output
assert output
@@ -138,10 +138,11 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_tensor = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TENSOR.value)
is_symmetric = weight_quant.symmetric and input_quant.symmetric
is_static = not weight_quant.dynamic and not input_quant.dynamic

return is_8_bits and is_tensor and is_symmetric and is_static
# Both symmetric and asymmetric input quantization supported.
# Only symmetric weight quantization supported.
return is_8_bits and is_tensor and weight_quant.symmetric and is_static

def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
@@ -151,10 +152,11 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_symmetric = weight_quant.symmetric and input_quant.symmetric
is_dynamic = not weight_quant.dynamic and input_quant.dynamic

return is_8_bits and is_token and is_symmetric and is_dynamic
# Both symmetric and asymmetric input quantization supported.
# Only symmetric weight quantization supported.
return is_8_bits and is_token and weight_quant.symmetric and is_dynamic

def _is_fp8_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
@@ -265,12 +267,14 @@ def _get_scheme_from_parts(
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=True)
is_static_input_scheme=True,
input_symmetric=input_quant.symmetric)

if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=False)
is_static_input_scheme=False,
input_symmetric=input_quant.symmetric)

raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")
@@ -3,6 +3,7 @@
import torch
from torch.nn import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
@@ -14,12 +15,16 @@
ModelWeightParameter,
PerTensorScaleParameter)

logger = init_logger(__name__)


class CompressedTensorsW8A8Int8(CompressedTensorsScheme):

def __init__(self, strategy: str, is_static_input_scheme: bool):
def __init__(self, strategy: str, is_static_input_scheme: bool,
input_symmetric: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric

@classmethod
def get_min_capability(cls) -> int:
@@ -46,10 +51,43 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
requires_grad=False)
# INPUT SCALE
if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
if self.input_symmetric:
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
layer.input_zero_point = None
else:
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = layer.input_zero_point.to(dtype=torch.int32)
range_max = (layer.input_scale *
(int8_traits.max - azps)).max()
range_min = (layer.input_scale *
(int8_traits.min - azps)).min()

scale = (range_max - range_min) / (int8_traits.max -
int8_traits.min)
layer.input_scale = Parameter(scale, requires_grad=False)

# AZP loaded as int8 but used as int32
azp = (int8_traits.min -
range_min / scale).to(dtype=torch.int32)
layer.input_zero_point = Parameter(azp, requires_grad=False)

else:
layer.input_scale = None
layer.input_zero_point = None

# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
if not self.input_symmetric:
layer.azp_adj = layer.weight.sum(dim=0,
keepdim=True,
dtype=torch.int32)
else:
layer.azp_adj = None

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
@@ -90,11 +128,22 @@ def create_weights(self, layer: torch.nn.Module,
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)

if not self.input_symmetric:
# Note: compressed-tensors stores the zp using the same dtype
# as the weights
# AZP loaded as int8 but used as int32
input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8),
weight_loader=weight_loader)
layer.register_parameter("input_zero_point", input_zero_point)

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:

return apply_int8_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
input_zero_point=layer.input_zero_point,
azp_adj=layer.azp_adj,
bias=bias)
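Two pieces of arithmetic in this scheme are worth spelling out. First, when a fused layer loads several static per-partition (scale, zero point) pairs, process_weights_after_loading rebuilds each partition's real value range, takes the union, and re-derives a single scale and AZP. Second, per csrc/quantization/cutlass_w8a8/Epilogues.md, subtracting the activation zero point before the matmul is equivalent to subtracting azp * weight.sum(dim=0) from the int32 accumulator, which is exactly what azp_adj precomputes. A small numeric sketch of both identities (illustrative only, not vLLM internals):

import torch

int8 = torch.iinfo(torch.int8)

# (1) Merge per-partition (scale, azp) pairs into one scale / zero point.
scales = torch.tensor([0.02, 0.03])
azps = torch.tensor([-10, 5], dtype=torch.int32)
range_max = (scales * (int8.max - azps)).max()
range_min = (scales * (int8.min - azps)).min()
scale = (range_max - range_min) / (int8.max - int8.min)
azp = (int8.min - range_min / scale).to(torch.int32)

# (2) The AZP adjustment identity used by the CUTLASS epilogue.
x_q = torch.randint(int8.min, int8.max + 1, (4, 16), dtype=torch.int32)
w_q = torch.randint(int8.min, int8.max + 1, (16, 8), dtype=torch.int32)
azp_adj = w_q.sum(dim=0, keepdim=True)
assert torch.equal((x_q - azp) @ w_q, x_q @ w_q - azp * azp_adj)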
19 changes: 17 additions & 2 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -191,13 +191,28 @@ def apply_int8_linear(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
input_zero_point: Optional[torch.Tensor] = None,
azp_adj: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
):
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale)

symmetric = azp_adj is None
x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale,
input_zero_point,
symmetric=symmetric)

if x_zp is not None:
return ops.cutlass_scaled_mm_azp(x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
azp_adj=azp_adj,
azp=x_zp,
bias=bias)
return ops.cutlass_scaled_mm(x_q,
weight,
scale_a=x_scale,
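Putting it together, the asymmetric branch of apply_int8_linear computes (up to kernel details) something equivalent to the following pure-PyTorch reference; the weight layout is assumed to be (in_features, out_features), scales are assumed per-tensor, and the names are illustrative rather than vLLM's API:

from typing import Optional

import torch

def ref_int8_linear_azp(x_q: torch.Tensor, w_q: torch.Tensor,
                        x_scale: torch.Tensor, w_scale: torch.Tensor,
                        azp: torch.Tensor, azp_adj: torch.Tensor,
                        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Accumulate in int32, then remove the activation zero-point contribution
    # via the precomputed azp_adj = w_q.sum(dim=0, keepdim=True).
    acc = x_q.to(torch.int32) @ w_q.to(torch.int32) - azp * azp_adj
    out = acc.to(torch.float32) * x_scale * w_scale
    return out + bias if bias is not None else out

The symmetric path is the same computation with azp == 0, which is why azp_adj is None there and the plain cutlass_scaled_mm call suffices.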
