-
-
Notifications
You must be signed in to change notification settings - Fork 9.1k
[Model] Add support for GraniteMoeShared models #13313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
DarkLight1337
merged 11 commits into
vllm-project:main
from
tjohnson31415:granite-shared-experts
Mar 4, 2025
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
1268cb5
first draft with shared experts support
tjohnson31415 6273278
fix: use getattr and delete tensor early
tjohnson31415 15524e4
fix: use MergedColumnParallelLinear
tjohnson31415 1e8fac0
refactor: move impl to dedicated granitesharedmoe.py file
tjohnson31415 10f6a39
update supported_lora_modules
tjohnson31415 4ed7972
linting
tjohnson31415 dc5effc
review: updates from code review
tjohnson31415 56d69cf
docs: update docs for granitemoeshared
tjohnson31415 3ed5c26
Update tests/models/registry.py
DarkLight1337 4467c5d
rebase: some modeling code cleanup after rebase
tjohnson31415 8630432
fix: graniteMoE does support PP (but requires enforce-eager)
tjohnson31415 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,343 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
"""Inference-only GraniteMoeShared model. | ||
|
||
The architecture is the same as granitemoe but with the addition of shared | ||
experts. | ||
""" | ||
from typing import Iterable, Optional, Set, Tuple | ||
|
||
import torch | ||
from torch import nn | ||
from transformers.models.granitemoeshared import GraniteMoeSharedConfig | ||
|
||
from vllm.compilation.decorators import support_torch_compile | ||
from vllm.config import CacheConfig, VllmConfig | ||
from vllm.distributed import get_pp_group | ||
from vllm.model_executor.layers.activation import SiluAndMul | ||
from vllm.model_executor.layers.layernorm import RMSNorm | ||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, | ||
RowParallelLinear) | ||
from vllm.model_executor.layers.logits_processor import LogitsProcessor | ||
from vllm.model_executor.layers.quantization.base_config import ( | ||
QuantizationConfig) | ||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler | ||
from vllm.model_executor.layers.vocab_parallel_embedding import ( | ||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) | ||
from vllm.model_executor.sampling_metadata import SamplingMetadata | ||
from vllm.sequence import IntermediateTensors | ||
|
||
from . import mixtral | ||
from .granitemoe import GraniteMoeAttention, GraniteMoeMoE | ||
from .interfaces import SupportsLoRA, SupportsPP | ||
from .utils import make_layers, maybe_prefix | ||
|
||
|
||
class GraniteMoeSharedMLP(nn.Module): | ||
|
||
def __init__( | ||
self, | ||
config: GraniteMoeSharedConfig, | ||
quant_config: Optional[QuantizationConfig] = None, | ||
prefix: str = "", | ||
): | ||
super().__init__() | ||
|
||
self.input_size = config.hidden_size | ||
self.hidden_size = config.shared_intermediate_size | ||
self.input_linear = MergedColumnParallelLinear( | ||
input_size=self.input_size, | ||
output_sizes=[self.hidden_size] * 2, | ||
bias=False, | ||
quant_config=quant_config, | ||
prefix=f"{prefix}.input_linear") | ||
self.output_linear = RowParallelLinear( | ||
tjohnson31415 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.hidden_size, | ||
self.input_size, | ||
bias=False, | ||
quant_config=quant_config, | ||
prefix=f"{prefix}.output_linear") | ||
if config.hidden_act != "silu": | ||
raise ValueError(f"Unsupported activation: {config.hidden_act}. " | ||
"Only silu is supported for now.") | ||
self.act_fn = SiluAndMul() | ||
|
||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||
hidden_states, _ = self.input_linear(hidden_states) | ||
hidden_states = self.act_fn(hidden_states) | ||
hidden_states, _ = self.output_linear(hidden_states) | ||
return hidden_states | ||
|
||
|
||
class GraniteMoeSharedDecoderLayer(nn.Module): | ||
|
||
def __init__( | ||
self, | ||
config: GraniteMoeSharedConfig, | ||
cache_config: Optional[CacheConfig] = None, | ||
quant_config: Optional[QuantizationConfig] = None, | ||
prefix: str = "", | ||
) -> None: | ||
super().__init__() | ||
self.hidden_size = config.hidden_size | ||
# Requires transformers > 4.32.0 | ||
rope_theta = getattr(config, "rope_theta", 10000) | ||
self.self_attn = GraniteMoeAttention( | ||
hidden_size=self.hidden_size, | ||
num_heads=config.num_attention_heads, | ||
max_position=config.max_position_embeddings, | ||
num_kv_heads=config.num_key_value_heads, | ||
rope_theta=rope_theta, | ||
cache_config=cache_config, | ||
quant_config=quant_config, | ||
prefix=f"{prefix}.self_attn", | ||
attention_multiplier=config.attention_multiplier) | ||
self.block_sparse_moe = GraniteMoeMoE( | ||
num_experts=config.num_local_experts, | ||
top_k=config.num_experts_per_tok, | ||
hidden_size=config.hidden_size, | ||
intermediate_size=config.intermediate_size, | ||
quant_config=quant_config, | ||
prefix=f"{prefix}.block_sparse_moe") | ||
self.shared_mlp = None if \ | ||
getattr(config, 'shared_intermediate_size', 0) == 0 \ | ||
else GraniteMoeSharedMLP( | ||
config, | ||
quant_config=quant_config, | ||
prefix=f"{prefix}.shared_mlp" | ||
) | ||
|
||
self.input_layernorm = RMSNorm(config.hidden_size, | ||
eps=config.rms_norm_eps) | ||
self.post_attention_layernorm = RMSNorm(config.hidden_size, | ||
eps=config.rms_norm_eps) | ||
|
||
self.residual_multiplier = config.residual_multiplier | ||
|
||
def forward( | ||
self, | ||
positions: torch.Tensor, | ||
hidden_states: torch.Tensor, | ||
) -> torch.Tensor: | ||
# Self Attention | ||
residual = hidden_states | ||
hidden_states = self.input_layernorm(hidden_states) | ||
hidden_states = self.self_attn( | ||
positions=positions, | ||
hidden_states=hidden_states, | ||
) | ||
hidden_states = residual + hidden_states * self.residual_multiplier | ||
residual = hidden_states | ||
hidden_states = self.post_attention_layernorm(hidden_states) | ||
if self.shared_mlp is None: | ||
hidden_states = self.block_sparse_moe(hidden_states) | ||
else: | ||
# create a copy since block_sparse_moe modifies in-place | ||
moe_hidden_states = hidden_states.clone() | ||
moe_hidden_states = self.block_sparse_moe(moe_hidden_states) | ||
hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) | ||
del moe_hidden_states | ||
hidden_states = residual + hidden_states * self.residual_multiplier | ||
|
||
return hidden_states | ||
|
||
|
||
@support_torch_compile | ||
class GraniteMoeSharedModel(nn.Module): | ||
|
||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): | ||
super().__init__() | ||
|
||
config = vllm_config.model_config.hf_config | ||
cache_config = vllm_config.cache_config | ||
quant_config = vllm_config.quant_config | ||
lora_config = vllm_config.lora_config | ||
|
||
self.padding_idx = config.pad_token_id | ||
lora_vocab = (lora_config.lora_extra_vocab_size * | ||
(lora_config.max_loras or 1)) if lora_config else 0 | ||
self.vocab_size = config.vocab_size + lora_vocab | ||
self.org_vocab_size = config.vocab_size | ||
|
||
self.embed_tokens = VocabParallelEmbedding( | ||
self.vocab_size, | ||
config.hidden_size, | ||
org_num_embeddings=config.vocab_size, | ||
tjohnson31415 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
quant_config=quant_config, | ||
) | ||
self.embedding_multiplier = config.embedding_multiplier | ||
|
||
self.start_layer, self.end_layer, self.layers = make_layers( | ||
config.num_hidden_layers, | ||
lambda prefix: GraniteMoeSharedDecoderLayer( | ||
config, cache_config, quant_config=quant_config, prefix=prefix | ||
), | ||
prefix=f"{prefix}.layers") | ||
|
||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||
|
||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: | ||
return self.embed_tokens(input_ids) | ||
|
||
def forward( | ||
self, | ||
input_ids: torch.Tensor, | ||
positions: torch.Tensor, | ||
intermediate_tensors: Optional[IntermediateTensors], | ||
inputs_embeds: Optional[torch.Tensor] = None, | ||
) -> torch.Tensor: | ||
if get_pp_group().is_first_rank: | ||
if inputs_embeds is not None: | ||
hidden_states = inputs_embeds | ||
else: | ||
hidden_states = self.get_input_embeddings(input_ids) | ||
hidden_states *= self.embedding_multiplier | ||
residual = None | ||
else: | ||
assert intermediate_tensors is not None | ||
hidden_states = intermediate_tensors["hidden_states"] | ||
residual = intermediate_tensors["residual"] | ||
for i in range(self.start_layer, self.end_layer): | ||
layer = self.layers[i] | ||
hidden_states = layer(positions, hidden_states) | ||
if not get_pp_group().is_last_rank: | ||
return IntermediateTensors({ | ||
"hidden_states": hidden_states, | ||
"residual": residual | ||
}) | ||
hidden_states = self.norm(hidden_states) | ||
return hidden_states | ||
|
||
|
||
class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): | ||
fall_back_to_pt_during_load = False | ||
|
||
packed_modules_mapping = { | ||
"qkv_proj": [ | ||
"q_proj", | ||
"k_proj", | ||
"v_proj", | ||
], | ||
} | ||
|
||
# LoRA specific attributes | ||
embedding_modules = { | ||
"embed_tokens": "input_embeddings", | ||
"lm_head": "output_embeddings", | ||
} | ||
embedding_padding_modules = ["lm_head"] | ||
|
||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): | ||
super().__init__() | ||
config = vllm_config.model_config.hf_config | ||
quant_config = vllm_config.quant_config | ||
lora_config = vllm_config.lora_config | ||
|
||
self.config = config | ||
self.lora_config = lora_config | ||
self.quant_config = quant_config | ||
|
||
self.model = GraniteMoeSharedModel(vllm_config=vllm_config, | ||
prefix=maybe_prefix( | ||
prefix, "model")) | ||
self.unpadded_vocab_size = config.vocab_size | ||
if lora_config: | ||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size | ||
self.lm_head = ParallelLMHead( | ||
self.unpadded_vocab_size, | ||
config.hidden_size, | ||
org_num_embeddings=config.vocab_size, | ||
padding_size=DEFAULT_VOCAB_PADDING_SIZE | ||
# We need bigger padding if using lora for kernel | ||
# compatibility | ||
if not lora_config else lora_config.lora_vocab_padding_size, | ||
quant_config=quant_config, | ||
tjohnson31415 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
prefix=maybe_prefix(prefix, "lm_head")) | ||
if config.tie_word_embeddings: | ||
self.lm_head.weight = self.model.embed_tokens.weight | ||
|
||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, | ||
config.vocab_size, | ||
scale=1 / | ||
self.config.logits_scaling) | ||
|
||
self.sampler = get_sampler() | ||
|
||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: | ||
return self.model.get_input_embeddings(input_ids) | ||
|
||
def forward( | ||
self, | ||
input_ids: torch.Tensor, | ||
positions: torch.Tensor, | ||
intermediate_tensors: Optional[IntermediateTensors] = None, | ||
inputs_embeds: Optional[torch.Tensor] = None, | ||
) -> torch.Tensor: | ||
hidden_states = self.model(input_ids, positions, intermediate_tensors, | ||
inputs_embeds) | ||
return hidden_states | ||
|
||
def compute_logits( | ||
self, hidden_states: torch.Tensor, | ||
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: | ||
logits = self.logits_processor(self.lm_head, hidden_states, | ||
sampling_metadata) | ||
return logits | ||
|
||
def make_empty_intermediate_tensors( | ||
self, batch_size: int, dtype: torch.dtype, | ||
device: torch.device) -> IntermediateTensors: | ||
return IntermediateTensors({ | ||
"hidden_states": | ||
torch.zeros((batch_size, self.config.hidden_size), | ||
dtype=dtype, | ||
device=device), | ||
"residual": | ||
torch.zeros((batch_size, self.config.hidden_size), | ||
dtype=dtype, | ||
device=device), | ||
}) | ||
|
||
def sample( | ||
self, | ||
logits: Optional[torch.Tensor], | ||
sampling_metadata: SamplingMetadata, | ||
) -> Optional[SamplerOutput]: | ||
next_tokens = self.sampler(logits, sampling_metadata) | ||
return next_tokens | ||
|
||
def load_weights(self, weights: Iterable[Tuple[str, | ||
torch.Tensor]]) -> Set[str]: | ||
new_weights = {} | ||
for n, p in weights: | ||
if n.endswith('.block_sparse_moe.input_linear.weight'): | ||
for e in range(p.size(0)): | ||
w1_name = n.replace( | ||
'.block_sparse_moe.input_linear.weight', | ||
f".block_sparse_moe.experts.{e}.w1.weight") | ||
w3_name = n.replace( | ||
'.block_sparse_moe.input_linear.weight', | ||
f".block_sparse_moe.experts.{e}.w3.weight") | ||
w1_param, w3_param = p[e].chunk(2, dim=0) | ||
assert w1_name not in new_weights | ||
assert w3_name not in new_weights | ||
new_weights[w1_name] = w1_param | ||
new_weights[w3_name] = w3_param | ||
elif n.endswith('.block_sparse_moe.output_linear.weight'): | ||
for e in range(p.size(0)): | ||
w2_name = n.replace( | ||
'.block_sparse_moe.output_linear.weight', | ||
f".block_sparse_moe.experts.{e}.w2.weight") | ||
w2_param = p[e] | ||
assert w2_name not in new_weights | ||
new_weights[w2_name] = w2_param | ||
elif n.endswith('.block_sparse_moe.router.layer.weight'): | ||
gate_name = n.replace('.block_sparse_moe.router.layer.weight', | ||
".block_sparse_moe.gate.weight") | ||
assert gate_name not in new_weights | ||
new_weights[gate_name] = p | ||
elif n == 'lm_head.weight' and self.config.tie_word_embeddings: | ||
pass | ||
else: | ||
new_weights[n] = p | ||
return mixtral.MixtralForCausalLM.load_weights(self, | ||
new_weights.items()) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
QQ: why doesn't input_linear support LoRA?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for taking a look!
Honestly, I don't know what is required for a layer to support LoRA... I presume that there is no reason for a simple linear layer not to, but do please let me know if there are reasons I would need to investigate 😅
I added
input_linear
andoutput_linear
to thesupported_lora_modules
.