[CI] Pin transformers<4.53.0 and fix EPLB load_weights to make CI passed #1482

Merged: 2 commits, Jun 27, 2025
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -12,12 +12,14 @@ requires = [
"scipy",
"setuptools>=64",
"setuptools-scm>=8",
"torch-npu==2.5.1.post1.dev20250528",
"torch-npu==2.5.1.post1.dev20250619",
"torch>=2.5.1",
"torchvision<0.21.0",
"wheel",
"msgpack",
"quart",
"numba",
+# Remove after https://github.com/vllm-project/vllm-ascend/issues/1470
+"transformers<4.53.0",
]
build-backend = "setuptools.build_meta"
3 changes: 3 additions & 0 deletions requirements.txt
@@ -25,3 +25,6 @@ numba
--pre
--extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi
torch-npu==2.5.1.post1.dev20250619

+# Remove after https://github.com/vllm-project/vllm-ascend/issues/1470
+transformers<4.53.0
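
The transformers ceiling is pinned in both pyproject.toml and requirements.txt so that source builds and installed CI environments resolve the same release. As a quick illustration, the snippet below is a hypothetical sanity check, not part of this PR, that a CI step could run to confirm the pin took effect; the version bound and the tracking issue are taken from the pins above.

# Hypothetical guard (not part of this PR): fail fast if the environment
# resolved a transformers release newer than the pinned bound.
from packaging.version import Version

import transformers

assert Version(transformers.__version__) < Version("4.53.0"), (
    f"transformers {transformers.__version__} is too new for vllm-ascend CI; "
    "see https://github.com/vllm-project/vllm-ascend/issues/1470")

Once issue 1470 is resolved upstream, both pins (and any guard like this) can be removed.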
117 changes: 110 additions & 7 deletions vllm_ascend/models/deepseek_dbo.py
@@ -25,7 +25,7 @@
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
# """Inference-only DeepseekV2/DeepseekV3 model."""

-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union

import torch
import torch.distributed as dist
@@ -49,16 +49,18 @@
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_v2 import \
    DeepseekV2ForCausalLM  # noqa: E501
from vllm.model_executor.models.deepseek_v2 import \
    yarn_get_mscale  # noqa: E501
-from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
-                                                    DeepseekV2DecoderLayer,
-                                                    DeepseekV2MLAAttention)
+from vllm.model_executor.models.deepseek_v2 import (
+    DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2MLAAttention,
+    get_spec_layer_idx_from_weight_name)
from vllm.model_executor.models.utils import (
-    PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
-    maybe_prefix)
+    PPMissingLayer, is_pp_missing_parameter,
+    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors

import vllm_ascend.envs as envs_ascend
@@ -76,7 +78,7 @@
make_multistream_metadata_ds)
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.utils import dispose_tensor
+from vllm_ascend.utils import dispose_tensor, vllm_version_is

VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO

@@ -963,6 +965,107 @@
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    # NOTE: This `load_weights` is mainly copied from
    # https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5
    # to fix CI, and it is different from the implementation in main
    # TODO: support eplb style load_weights
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """"""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = AscendFusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    if vllm_version_is("0.9.1"):
                        weight_loader(param,
                                      loaded_weight,
                                      name,
                                      shard_id=shard_id,
                                      expert_id=expert_id)
                    else:
                        weight_loader(param,
                                      loaded_weight,
                                      name,
                                      shard_id=shard_id,
                                      expert_id=expert_id,
                                      return_success=False)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def forward(
        self,
        input_ids: torch.Tensor,
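For reference, the sketch below restates, outside the model class, the stacked-layer name rewriting that the copied load_weights performs: dense gate_proj/up_proj tensors are folded into the fused gate_up_proj parameter, while expert tensors are left untouched so that expert_params_mapping can claim them later. The helper remap_stacked and the example checkpoint names are hypothetical; the mapping tuples are copied from the diff above, and the expert check is simplified (the real loader also consults params_dict).

# Standalone illustration (hypothetical helper and example names) of the
# stacked-parameter name rewriting done by load_weights above.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id) -- as in the diff above
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def remap_stacked(name: str):
    """Return (rewritten_name, shard_id) for stacked layers, else None."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        # Expert weights must be skipped before renaming; in the real loader
        # they are matched later against expert_params_mapping instead.
        if "mlp.experts." in name:
            return None
        return name.replace(weight_name, param_name), shard_id
    return None


# A dense-layer tensor is folded into the fused gate_up_proj parameter:
print(remap_stacked("model.layers.0.mlp.gate_proj.weight"))
# -> ('model.layers.0.mlp.gate_up_proj.weight', 0)

# An expert tensor is not rewritten here; expert_params_mapping handles it:
print(remap_stacked("model.layers.3.mlp.experts.0.up_proj.weight"))
# -> None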
117 changes: 110 additions & 7 deletions vllm_ascend/models/deepseek_v2.py
@@ -25,7 +25,7 @@
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
# """Inference-only DeepseekV2/DeepseekV3 model."""

-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch_npu
@@ -55,16 +55,18 @@
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_v2 import \
    DeepseekV2ForCausalLM  # noqa: E501
from vllm.model_executor.models.deepseek_v2 import \
    yarn_get_mscale  # noqa: E501
-from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
-                                                    DeepseekV2DecoderLayer,
-                                                    DeepseekV2MLAAttention)
+from vllm.model_executor.models.deepseek_v2 import (
+    DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2MLAAttention,
+    get_spec_layer_idx_from_weight_name)
from vllm.model_executor.models.utils import (
-    PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
-    maybe_prefix)
+    PPMissingLayer, is_pp_missing_parameter,
+    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors

from vllm_ascend.ascend_config import get_ascend_config
@@ -73,7 +75,7 @@
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
-                               npu_wait_tensor)
+                               npu_wait_tensor, vllm_version_is)


class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -867,6 +869,107 @@
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    # NOTE: This `load_weights` is mainly copied from
    # https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5
    # to fix CI, and it is different from the implementation in main
    # TODO: support eplb style load_weights
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """"""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = AscendFusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    if vllm_version_is("0.9.1"):
                        weight_loader(param,
                                      loaded_weight,
                                      name,
                                      shard_id=shard_id,
                                      expert_id=expert_id)
                    else:
                        weight_loader(param,
                                      loaded_weight,
                                      name,
                                      shard_id=shard_id,
                                      expert_id=expert_id,
                                      return_success=False)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def forward(
        self,
        input_ids: torch.Tensor,
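Both model files gate the expert weight_loader call on vllm_version_is("0.9.1"): on vLLM 0.9.1 the FusedMoE loader is called with only the shard_id and expert_id keywords, while newer vLLM adds a return_success flag that the copied code sets to False. The sketch below isolates that gate; it assumes param.weight_loader is the AscendFusedMoE expert loader, and load_expert_weight is a hypothetical wrapper for illustration, not a helper added by this PR.

from vllm_ascend.utils import vllm_version_is


def load_expert_weight(param, loaded_weight, name, shard_id, expert_id):
    # Dispatch to the expert weight loader with the keyword set that the
    # installed vLLM version expects (mirrors the diff above).
    weight_loader = param.weight_loader
    if vllm_version_is("0.9.1"):
        # vLLM 0.9.1: the loader returns nothing.
        weight_loader(param, loaded_weight, name,
                      shard_id=shard_id, expert_id=expert_id)
    else:
        # Newer vLLM can report whether the weight was consumed; the copied
        # load_weights opts out by passing return_success=False.
        weight_loader(param, loaded_weight, name,
                      shard_id=shard_id, expert_id=expert_id,
                      return_success=False)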