
[main] Refactoring w4a8 and w8a8 and supporting deepseek w4a8 #1469

Open · wants to merge 12 commits into main
476 changes: 239 additions & 237 deletions .github/workflows/vllm_ascend_test.yaml

Large diffs are not rendered by default.

65 changes: 65 additions & 0 deletions tests/e2e/multicard/test_model_qwen3_w4a8.py
@@ -0,0 +1,65 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
"""Compare the outputs of vLLM when using W4A8 quantization on qwen3 models.

Run `pytest tests/e2e/multicard/test_model_qwen3_w4a8.py`.
"""
import os

import pytest
from modelscope import snapshot_download # type: ignore
from vllm import LLM, SamplingParams

MODELS = ["vllm-ascend/Qwen3-8B-W4A8"]
PROMPTS = [
"Hello, my name is",
"The future of AI is",
]


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
reason="w4a8_dynamic is not supported on v0")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [16])
def test_qwen3_model_with_w4a8_linear_method(model: str,
max_tokens: int) -> None:
messages = [[{"role": "user", "content": prompt}] for prompt in PROMPTS]
sampling_params = SamplingParams(
max_tokens=max_tokens,
temperature=0.0,
)
llm = LLM(
model=snapshot_download(model),
max_model_len=1024,
tensor_parallel_size=2,
enforce_eager=True,
quantization="ascend",
)
vllm_outputs = llm.chat(
messages,
sampling_params,
chat_template_kwargs={"enable_thinking": False},
)
golden_outputs = [
"Hello! My name is Qwen, and I'm a large language model developed",
"The future of AI is a topic of great interest and debate, with many possibilities",
]
assert len(vllm_outputs) == len(golden_outputs)
for vllm_output, golden_output in zip(vllm_outputs, golden_outputs):
assert vllm_output.outputs[0].text == golden_output
print(f"Generated text: {vllm_output.outputs[0].text!r}")
65 changes: 65 additions & 0 deletions tests/e2e/multicard/test_w4a8_deepseek.py
@@ -0,0 +1,65 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from unittest.mock import patch

import pytest
from modelscope import snapshot_download # type: ignore

from tests.conftest import VllmRunner


model_name = snapshot_download("vllm-ascend/DeepSeek-R1-w4a8-pruning")


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
reason="w4a8_dynamic is not supported on v0")
@patch.dict(os.environ, {"VLLM_USE_V1": "1", "VLLM_ASCEND_MLA_PA": "1"})
def test_deepseek_W4A8():
prompts = [
"The capital of France is",
"The future of AI is",
]
dtype = "bfloat16"
max_tokens = 5
with VllmRunner(
model_name,
dtype=dtype,
tensor_parallel_size=2,
quantization="ascend",
enforce_eager=True,
enable_expert_parallel=True,
additional_config={
"torchair_graph_config": {
"enabled": False,
},
"ascend_scheduler_config": {
"enabled": True,
}
},
) as vllm_model:
# use greedy sampling to make sure the generated results are deterministic
vllm_output = vllm_model.generate_greedy(prompts, max_tokens)

golden_results = [
'The capital of France is逸 Ban Corporealistically',
'The future of AI is逸 Ban Corporealistically',
]
assert len(golden_results) == len(vllm_output)
for i in range(len(vllm_output)):
assert golden_results[i] == vllm_output[i][1]
print(f"Generated text: {vllm_output[i][1]!r}")
2 changes: 2 additions & 0 deletions vllm_ascend/models/deepseek_v2.py
@@ -895,6 +895,8 @@ def load_weights(self, weights: Iterable[tuple[str,
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "module" in name:
continue

spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
15 changes: 15 additions & 0 deletions vllm_ascend/quantization/quant_config.py
@@ -194,6 +194,17 @@ def create_weights(
layer.register_parameter(perchannel_name, param)
set_weight_attrs(param, extra_weight_attrs)

pergroup_dict = self.quant_method.get_pergroup_param(
input_size_per_partition, output_size_per_partition, params_dtype)
for pergroup_name, pergroup_param in pergroup_dict.items():
param = torch.nn.Parameter(pergroup_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
layer.register_parameter(pergroup_name, param)
set_weight_attrs(param, extra_weight_attrs)
if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name:
setattr(param, "input_dim", 1)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)
@@ -305,6 +316,10 @@ def create_weights(
param = torch.nn.Parameter(param_value, requires_grad=False)
layer.register_parameter(param_key, param)
set_weight_attrs(param, extra_weight_attrs)
if "weight_scale_second" in param_key or "weight_offset_second" in param_key:
setattr(param, "quant_method",
FusedMoeWeightScaleSupported.GROUP.value)

def apply(
self,
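The two hunks above register the extra "second-level" (per-group) scale and offset parameters that W4A8 needs: linear layers mark them with input_dim = 1, and fused-MoE layers tag them with the GROUP quant method. Below is a minimal sketch, not the PR's kernel path, of how such a group-wise scale can relate to an int4 weight layout; the sizes and group_size are hypothetical and the real shapes come from get_pergroup_param().

import torch

output_size, input_size, group_size = 128, 256, 64  # hypothetical sizes

# int4 values carried in an int8 tensor purely for illustration
weight = torch.randint(-8, 8, (output_size, input_size), dtype=torch.int8)

# per-channel scale: one value per output channel (output_dim = 0)
weight_scale = torch.rand(output_size, 1)

# per-group "second" scale: one value per group of input channels,
# which is why create_weights() sets input_dim = 1 on it
weight_scale_second = torch.rand(output_size, input_size // group_size)

# group-wise dequantization (illustrative only; the NPU kernels fuse this)
w = weight.float().view(output_size, -1, group_size)
w = w * weight_scale_second.unsqueeze(-1)
w_fp = w.reshape(output_size, input_size) * weight_scale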
14 changes: 14 additions & 0 deletions vllm_ascend/quantization/quantizer.py
@@ -24,6 +24,8 @@

from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot,
wrapper_rmsnorm_init)
from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
AscendW4A8DynamicLinearMethod)
from .w8a8 import AscendW8A8LinearMethod
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
AscendW8A8DynamicLinearMethod)
@@ -263,6 +265,17 @@ def get_quantizer(cls,
f"{list(SUPPORT_ASCEND_QUANTIZER_TYPE.keys())}")


class W4A8DYNAMICQuantizer(VLLMAscendQuantizer):

@staticmethod
def build_linear_method():
return AscendW4A8DynamicLinearMethod()

@staticmethod
def build_moe_method():
return AscendW4A8DynamicFusedMoEMethod()


class W8A8Quantizer(VLLMAscendQuantizer):

@staticmethod
@@ -282,6 +295,7 @@ def build_moe_method():


SUPPORT_ASCEND_QUANTIZER_TYPE = {
"W4A8_DYNAMIC": W4A8DYNAMICQuantizer,
"W8A8": W8A8Quantizer,
"W8A8_DYNAMIC": W8A8DYNAMICQuantizer,
}
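
With the registry entry above, an Ascend checkpoint whose quantization description reports W4A8_DYNAMIC now resolves to the new quantizer. A minimal usage sketch, assuming the quant type string has already been parsed from the checkpoint's quantization config (this is not the exact call path used by quant_config.py):

quant_type = "W4A8_DYNAMIC"  # assumed to come from the checkpoint's quant description

quantizer_cls = SUPPORT_ASCEND_QUANTIZER_TYPE[quant_type]
linear_method = quantizer_cls.build_linear_method()  # AscendW4A8DynamicLinearMethod
moe_method = quantizer_cls.build_moe_method()  # AscendW4A8DynamicFusedMoEMethod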