
Commit 04ea41e

mzusman authored and weilong.yu committed
[Model] PP support for Mamba-like models (vllm-project#10992)
Signed-off-by: mzusman <mor.zusmann@gmail.com>
1 parent 67965a2 commit 04ea41e

11 files changed, +229 −81 lines

docs/source/models/supported_models.rst

Lines changed: 3 additions & 3 deletions

@@ -128,7 +128,7 @@ Text Generation
     - FalconMamba
     - :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc.
     - ✅︎
-    -
+    - ✅︎
   * - :code:`GemmaForCausalLM`
     - Gemma
     - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
@@ -193,7 +193,7 @@ Text Generation
     - Jamba
     - :code:`ai21labs/AI21-Jamba-1.5-Large`, :code:`ai21labs/AI21-Jamba-1.5-Mini`, :code:`ai21labs/Jamba-v0.1`, etc.
     - ✅︎
-    -
+    - ✅︎
   * - :code:`LlamaForCausalLM`
     - Llama 3.1, Llama 3, Llama 2, LLaMA, Yi
     - :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc.
@@ -203,7 +203,7 @@ Text Generation
     - Mamba
     - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
     -
-    -
+    - ✅︎
   * - :code:`MiniCPMForCausalLM`
     - MiniCPM
     - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, :code:`openbmb/MiniCPM-S-1B-sft`, etc.

tests/distributed/test_pipeline_parallel.py

Lines changed: 4 additions & 2 deletions

@@ -156,13 +156,13 @@ def iter_params(self, model_name: str):
     # "internlm/internlm-chat-7b": PPTestSettings.fast(),
     "internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True),
     "inceptionai/jais-13b-chat": PPTestSettings.fast(),
-    # TODO: Implement PP
-    # "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(),
+    "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
     "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
     "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True),
     "openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True),
     # Uses Llama
     # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
+    "state-spaces/mamba-130m-hf": PPTestSettings.fast(),
     "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4),
     "mosaicml/mpt-7b": PPTestSettings.fast(),
     "nvidia/Minitron-8B-Base": PPTestSettings.fast(),
@@ -234,6 +234,8 @@ def iter_params(self, model_name: str):
     "OpenGVLab/InternVL2-1B",
     "microsoft/Phi-3-vision-128k-instruct",
     "fixie-ai/ultravox-v0_3",
+    # [LANGUAGE GENERATION - HYBRID ARCH]
+    "ai21labs/Jamba-tiny-dev",
 ]
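The list above only declares which checkpoints the PP test suite iterates over; the suite itself drives them end to end with tensor and pipeline parallelism. As a rough illustration only (this is not the test harness, and the parallel sizes and backend are assumptions), a newly enabled hybrid checkpoint could be smoke-tested through the Python API like this:

```python
# Hedged sketch: a minimal pipeline-parallel smoke run for one of the newly
# enabled models. Assumes at least 2 GPUs; "mp" backend is an assumption,
# ray is the other option vLLM supports for distributed execution.
from vllm import LLM, SamplingParams

if __name__ == "__main__":
    llm = LLM(
        model="ai21labs/Jamba-tiny-dev",    # hybrid attention/mamba checkpoint
        pipeline_parallel_size=2,
        distributed_executor_backend="mp",
    )
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)
```
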
vllm/config.py

Lines changed: 44 additions & 14 deletions

@@ -27,8 +27,8 @@
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
-from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        print_warning_once, random_uuid,
+from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
+                        get_cpu_memory, print_warning_once, random_uuid,
                         resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
@@ -284,6 +284,7 @@ def __init__(
         self._verify_tokenizer_mode()
 
         self.is_attention_free = self._init_attention_free()
+        self.is_hybrid = self._init_is_hybrid()
         self.has_inner_state = self._init_has_inner_state()
 
         if current_platform.is_neuron():
@@ -340,6 +341,10 @@ def _init_attention_free(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.is_attention_free_model(architectures)
 
+    def _init_is_hybrid(self) -> bool:
+        architectures = getattr(self.hf_config, "architectures", [])
+        return ModelRegistry.is_hybrid_model(architectures)
+
     def _init_has_inner_state(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.model_has_inner_state(architectures)
@@ -669,26 +674,51 @@ def get_num_attention_heads(self,
         num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
         return num_heads // parallel_config.tensor_parallel_size
 
-    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
+    def get_layers_start_end_indices(
+            self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
         from vllm.distributed.utils import get_pp_indices
         total_num_hidden_layers = getattr(self.hf_text_config,
                                           "num_hidden_layers", 0)
         pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
         pp_size = parallel_config.pipeline_parallel_size
         start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
-        return end - start
-
-    def get_num_attention_layers(self,
-                                 parallel_config: "ParallelConfig") -> int:
-        if self.is_attention_free:
-            return 0
+        return start, end
 
-        num_layers = self.get_num_layers(parallel_config)
+    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
+        start, end = self.get_layers_start_end_indices(parallel_config)
+        return end - start
 
-        # Transformers supports layers_block_type @property
-        layers = getattr(self.hf_config, "layers_block_type",
-                         ["attention"] * num_layers)
-        return len([t for t in layers if t == "attention"])
+    def get_num_layers_by_block_type(
+        self,
+        parallel_config: "ParallelConfig",
+        block_type: LayerBlockType = LayerBlockType.attention,
+    ) -> int:
+        # This function relies on 'layers_block_type' in hf_config,
+        # for w/o this attribute, we will need to have workarounds like so
+        attn_block_type = block_type == LayerBlockType.attention
+        is_transformer = not self.is_hybrid and not self.is_attention_free
+        start, end = self.get_layers_start_end_indices(parallel_config)
+
+        if is_transformer:
+            # Handle the basic case first
+            return end - start if attn_block_type else 0
+        elif self.is_attention_free:
+            # Attention free
+            # Note that this code assumes there
+            # is only one type of attention-free block type.
+            return 0 if attn_block_type else end - start
+        else:
+            # Hybrid model
+            layers_block_type_value = getattr(self.hf_config,
+                                              "layers_block_type", None)
+            if layers_block_type_value is None:
+                raise ValueError("The model is an hybrid without a"
+                                 "layers_block_type in the hf_config,"
+                                 "cannot determine the num of "
+                                 f"{block_type.value} layers")
+
+            return sum(t == block_type.value
+                       for t in layers_block_type_value[start:end])
 
     def get_multimodal_config(self) -> "MultiModalConfig":
         """

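To make the new accounting concrete, here is a small self-contained sketch (not vLLM code) of what get_num_layers_by_block_type computes for a hybrid checkpoint: given the hf_config's layers_block_type list and the contiguous [start, end) layer slice assigned to a pipeline rank, it counts the matching block types inside that slice. The layer schedule below is invented for the example.

```python
# Hedged illustration of the counting done by
# ModelConfig.get_num_layers_by_block_type for a hybrid model.
# demo_layers_block_type is made up; real values come from
# the checkpoint's hf_config.layers_block_type.
from typing import List, Tuple


def count_block_type(layers_block_type: List[str],
                     pp_slice: Tuple[int, int],
                     block_type: str) -> int:
    """Count layers of `block_type` inside this pipeline rank's slice."""
    start, end = pp_slice
    return sum(t == block_type for t in layers_block_type[start:end])


# A toy 8-layer hybrid schedule: attention every 4th layer, mamba otherwise.
demo_layers_block_type = [
    "mamba", "mamba", "mamba", "attention",
    "mamba", "mamba", "mamba", "attention",
]

# With pipeline_parallel_size=2, rank 0 owns layers [0, 4) and rank 1 [4, 8).
for rank, pp_slice in enumerate([(0, 4), (4, 8)]):
    n_attn = count_block_type(demo_layers_block_type, pp_slice, "attention")
    n_mamba = count_block_type(demo_layers_block_type, pp_slice, "mamba")
    print(f"rank {rank}: {n_attn} attention layers, {n_mamba} mamba layers")
```

Each rank can then size its KV cache from its local attention count and its Mamba state cache from its local mamba count, which is what the Jamba changes below rely on.
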
vllm/model_executor/models/interfaces.py

Lines changed: 37 additions & 0 deletions

@@ -363,6 +363,43 @@ def is_attention_free(
     return isinstance(model, IsAttentionFree)
 
 
+@runtime_checkable
+class IsHybrid(Protocol):
+    """The interface required for all models like Jamba that have both
+    attention and mamba blocks, indicates that
+    hf_config has 'layers_block_type'"""
+
+    is_hybrid: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model has both mamba and attention blocks
+    , also indicates that the model's hf_config has
+    'layers_block_type' """
+
+
+@runtime_checkable
+class _IsHybridType(Protocol):
+    is_hybrid: ClassVar[Literal[True]]
+
+
+@overload
+def is_hybrid(model: object) -> TypeIs[IsHybrid]:
+    ...
+
+
+@overload
+def is_hybrid(model: Type[object]) -> TypeIs[Type[IsHybrid]]:
+    ...
+
+
+def is_hybrid(
+    model: Union[Type[object], object]
+) -> Union[TypeIs[Type[IsHybrid]], TypeIs[IsHybrid]]:
+    if isinstance(model, type):
+        return isinstance(model, _IsHybridType)
+
+    return isinstance(model, IsHybrid)
+
+
 @runtime_checkable
 class SupportsCrossEncoding(Protocol):
     """The interface required for all models that support cross encoding."""

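For readers unfamiliar with this marker-protocol pattern, here is a minimal self-contained sketch of how such a runtime-checkable interface behaves. It is deliberately simplified: it returns a plain bool, drops the TypeIs overloads, and the ToyJamba/ToyLlama classes are stand-ins invented for the example.

```python
# Hedged sketch of the marker-protocol pattern used by IsHybrid / is_hybrid.
# Simplified: no TypeIs narrowing; toy classes are not vLLM's.
from typing import ClassVar, Literal, Protocol, Type, Union, runtime_checkable


@runtime_checkable
class IsHybrid(Protocol):
    """Marker for models that mix attention and mamba blocks."""
    is_hybrid: ClassVar[Literal[True]]


def is_hybrid(model: Union[Type[object], object]) -> bool:
    # Works on both classes (registry-time checks) and instances.
    return isinstance(model, IsHybrid)


class ToyJamba:
    is_hybrid: ClassVar[Literal[True]] = True


class ToyLlama:
    pass


print(is_hybrid(ToyJamba))    # True  (class-level check)
print(is_hybrid(ToyJamba()))  # True  (instance check)
print(is_hybrid(ToyLlama))    # False
```

The registry uses this flag, together with hf_config's layers_block_type, to decide when the per-block-type layer counting above applies.
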
vllm/model_executor/models/jamba.py

Lines changed: 65 additions & 28 deletions

@@ -9,6 +9,7 @@
 from vllm.attention.layer import Attention
 from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -25,9 +26,12 @@
                                                 MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.utils import LayerBlockType
 
-from .interfaces import HasInnerState, SupportsLoRA
-from .utils import maybe_prefix
+from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -281,16 +285,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             org_num_embeddings=config.vocab_size,
         )
 
-        decoder_layers = []
-        for i in range(config.num_hidden_layers):
-            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
-            decoder_layers.append(
-                layer_class(config,
-                            layer_idx=i,
-                            cache_config=cache_config,
-                            quant_config=quant_config,
-                            prefix=f"{prefix}.layers.{i}"))
-        self.layers = nn.ModuleList(decoder_layers)
+        def get_layer(prefix: str):
+            layer_idx = int(prefix.rsplit(".", 1)[1])
+            layer_class = ALL_DECODER_LAYER_TYPES[
+                config.layers_block_type[layer_idx]]
+            return layer_class(
+                config,
+                layer_idx,
+                cache_config,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
+
         self.final_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
 
@@ -304,26 +316,34 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
         else:
-            hidden_states = self.get_input_embeddings(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        kv_cache_index = 0
+        mamba_cache_index = 0
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             kv_cache = None
             layer_mamba_cache_params = None
             if isinstance(layer, JambaAttentionDecoderLayer):
-                kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
-                                     self.config.attn_layer_period]
+                kv_cache = kv_caches[kv_cache_index]
+                kv_cache_index += 1
             if isinstance(layer, JambaMambaDecoderLayer):
-                current_state_layer = i - (1 +
-                                           (i - self.config.attn_layer_offset)
-                                           // self.config.attn_layer_period)
+                current_state_layer = mamba_cache_index
                 layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                     current_state_layer)
+                mamba_cache_index += 1
 
             hidden_states, residual = layer(
                 positions=positions,
@@ -332,11 +352,17 @@ def forward(
                 attn_metadata=attn_metadata,
                 residual=residual,
                 mamba_cache_params=layer_mamba_cache_params)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.final_layernorm(hidden_states, residual)
         return hidden_states
 
 
-class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
+class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
+                       IsHybrid):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -368,6 +394,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         super().__init__()
         self.config = config
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
         self.scheduler_config = scheduler_config
         self.model = JambaModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
@@ -390,6 +418,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                                 config.vocab_size)
         self.sampler = get_sampler()
 
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
@@ -406,10 +437,8 @@ def forward(self,
                 self.scheduler_config.max_num_seqs) if self.scheduler_config
                               else max(_BATCH_SIZES_TO_CAPTURE) + 2)
 
-            layers_type = self.config.layers_block_type
-            num_mamba_layers = sum(
-                [layer_type == "mamba" for layer_type in layers_type])
-
+            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
+                self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
                 self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
                 *self._get_mamba_cache_shape())
@@ -423,7 +452,7 @@ def forward(self,
                                           state_indices_tensor)
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, mamba_cache_params,
-                                   inputs_embeds)
+                                   intermediate_tensors, inputs_embeds)
         return hidden_states
 
     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
@@ -504,8 +533,12 @@ def load_weights(self, weights: Iterable[Tuple[str,
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
+
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -520,6 +553,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
                     if weight_name not in name:
                         continue
 
+                    if is_pp_missing_parameter(name, self):
+                        continue
                     name = name.replace(weight_name, param_name)
                     param = params_dict[name]
                     weight_loader = param.weight_loader
@@ -533,6 +568,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    if is_pp_missing_parameter(name, self):
+                        continue
 
                     param = params_dict[name]
                     weight_loader = getattr(param, "weight_loader",
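The net effect of the forward-pass change is that each pipeline rank only walks its own contiguous [start_layer, end_layer) slice and indexes its caches locally: kv_caches holds one entry per attention layer on this rank, and the Mamba cache one slot per mamba layer on this rank, so the old global attn_layer_offset/attn_layer_period arithmetic is no longer needed. A small standalone sketch of that local indexing (toy layer schedule, not vLLM code):

```python
# Hedged sketch of the per-rank cache indexing introduced in JambaModel.forward:
# caches are addressed by running counters over the rank-local layer slice,
# not by global offset/period arithmetic. The layer schedule is made up.
from typing import Dict, List, Tuple


def local_cache_indices(layers_block_type: List[str],
                        start_layer: int,
                        end_layer: int) -> Dict[int, Tuple[str, int]]:
    """Map each global layer index on this rank to (block type, local cache slot)."""
    kv_cache_index = 0
    mamba_cache_index = 0
    mapping: Dict[int, Tuple[str, int]] = {}
    for i in range(start_layer, end_layer):
        if layers_block_type[i] == "attention":
            mapping[i] = ("attention", kv_cache_index)
            kv_cache_index += 1
        else:
            mapping[i] = ("mamba", mamba_cache_index)
            mamba_cache_index += 1
    return mapping


# Toy 8-layer hybrid model split across 2 pipeline ranks.
schedule = ["mamba", "mamba", "mamba", "attention",
            "mamba", "mamba", "mamba", "attention"]
print(local_cache_indices(schedule, 0, 4))  # rank 0: kv slot 0 at layer 3
print(local_cache_indices(schedule, 4, 8))  # rank 1: kv slot 0 at layer 7
```

The is_pp_missing_parameter checks in load_weights follow the same idea on the weight side: each rank skips parameters that belong to layers outside its slice.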
