Skip to content

Commit

Permalink
[BugFix] Fix weight loading for Mixtral with TP (vllm-project#2208)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Dec 20, 2023
1 parent 8320bfc commit d8865a9
Showing 1 changed file with 5 additions and 26 deletions.
31 changes: 5 additions & 26 deletions vllm/model_executor/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]
Expand Down Expand Up @@ -94,30 +93,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return current_hidden_states


class DummyModule(nn.Module):

def __init__(self) -> None:
super().__init__()

self.w1 = nn.Linear(0, 0, bias=False)
self.w2 = nn.Linear(0, 0, bias=False)
self.w3 = nn.Linear(0, 0, bias=False)

set_weight_attrs(self.w1.weight,
{"weight_loader": self.dummy_weight_loader})
set_weight_attrs(self.w2.weight,
{"weight_loader": self.dummy_weight_loader})
set_weight_attrs(self.w3.weight,
{"weight_loader": self.dummy_weight_loader})

def forward(self, *args, **kwargs) -> None:
raise NotImplementedError()

def dummy_weight_loader(self, *args, **kwargs) -> None: # pylint: disable=unused-argument
# Noop
return


class MixtralMoE(nn.Module):

def __init__(
Expand Down Expand Up @@ -147,7 +122,7 @@ def __init__(
config.hidden_size,
config.intermediate_size,
linear_method=linear_method)
if idx in self.expert_indicies else DummyModule()
if idx in self.expert_indicies else None
for idx in range(self.num_total_experts)
])
self.gate = ReplicatedLinear(config.hidden_size,
Expand Down Expand Up @@ -427,6 +402,10 @@ def load_weights(self,
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if ("block_sparse_moe.experts." in name
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
Expand Down

0 comments on commit d8865a9

Please sign in to comment.