
Commit 988718e

Jamba official hf (vllm-project#14)
* remove JambaConfig and use official one from transformers
* changes in Jamba modeling file to align with official HF format
1 parent af7a4ac commit 988718e
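
The modeling code below builds its layer layout directly from the official transformers JambaConfig. A minimal sketch of what it relies on, assuming a transformers release that ships a Jamba config exposing the layers_block_type and layers_num_experts attributes used in the diff; the tiny config values here are illustrative only:

from transformers import JambaConfig

# Hypothetical small config, just to inspect the per-layer schedule the
# vLLM model builder reads off the official config object.
config = JambaConfig(num_hidden_layers=8)

# One entry per decoder layer: "attention" or "mamba".
print(config.layers_block_type)
# One entry per decoder layer: 1 for dense FFN layers, num_experts for MoE layers.
print(config.layers_num_experts)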

File tree

2 files changed: +89, -126 lines

vllm/model_executor/models/jamba.py

Lines changed: 87 additions & 126 deletions
@@ -4,15 +4,14 @@
 from typing import Iterable, List, Optional, Tuple
 
 import torch
-from transformers import JambaConfig
-from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
-from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from torch import nn
-from torch.nn.parameter import Parameter
 
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
+
+from transformers import JambaConfig
+from torch.nn.parameter import Parameter
 from vllm.config import LoRAConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -33,6 +32,9 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import SamplerOutput
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+from mamba_ssm.ops.triton.selective_state_update import selective_state_update
+from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -43,7 +45,6 @@ class MambaCacheParams:
     ssm_state: torch.Tensor = torch.Tensor()
 
 
-
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 class JambaMambaMixer(nn.Module):
     """
@@ -124,28 +125,10 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
             input_is_parallel=True,
         )
         self.activation = config.hidden_act
-        self.apply_inner_layernorms = config.mamba_inner_layernorms
-
-        if self.apply_inner_layernorms:
-            self.dt_layernorm = RMSNorm(self.time_step_rank,
-                                        eps=config.rms_norm_eps)
-            self.B_layernorm = RMSNorm(self.ssm_state_size,
-                                       eps=config.rms_norm_eps)
-            self.C_layernorm = RMSNorm(self.ssm_state_size,
-                                       eps=config.rms_norm_eps)
-        else:
-            self.dt_layernorm = None
-            self.B_layernorm = None
-            self.C_layernorm = None
-
-    def _apply_layernorms(self, dt, B, C):
-        if self.dt_layernorm is not None:
-            dt = self.dt_layernorm.forward(dt.contiguous())
-        if self.B_layernorm is not None:
-            B = self.B_layernorm.forward(B.contiguous())
-        if self.C_layernorm is not None:
-            C = self.C_layernorm.forward(C.contiguous())
-        return dt, B, C
+
+        self.dt_layernorm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
+        self.b_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
+        self.c_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
 
     def mamba_forward(self,
                       hidden_states: torch.Tensor,
@@ -189,7 +172,9 @@ def mamba_forward(self,
             [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
             dim=-1,
         )
-        time_step, B, C = self._apply_layernorms(time_step, B, C)
+        time_step = self.dt_layernorm(time_step.contiguous())
+        B = self.b_layernorm(B.contiguous())
+        C = self.c_layernorm(C.contiguous())
 
         discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2)
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
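
For reference, a plain-torch sketch (an illustrative assumption, not vLLM's RMSNorm implementation) of the normalization the dt/B/C layer norms above apply before the selective scan:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Root-mean-square normalization over the last dimension, then a learned scale.
    input_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(input_dtype)
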
@@ -275,6 +260,36 @@ def forward(
         return hidden_states
 
 
+class JambaMLP(nn.Module):
+    def __init__(
+        self,
+        config: JambaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size
+        hidden_act = config.hidden_act
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           quant_config=quant_config)
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
 class JambaMoE(nn.Module):
     """A tensor-parallel MoE implementation for Mixtral that shards each expert
     across all ranks.
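
The new JambaMLP pairs a fused gate/up projection with vLLM's SiluAndMul activation. A plain-PyTorch sketch of the computation it performs, using unsharded nn.Linear stand-ins for the parallel layers (illustration only):

import torch
from torch import nn
import torch.nn.functional as F

def jamba_mlp_reference(x: torch.Tensor, gate_proj: nn.Linear, up_proj: nn.Linear,
                        down_proj: nn.Linear) -> torch.Tensor:
    # gate_up_proj stacks gate_proj and up_proj column-wise; SiluAndMul then splits
    # the result in half and gates the "up" half with silu of the "gate" half.
    gate_up = torch.cat([gate_proj(x), up_proj(x)], dim=-1)
    d = gate_up.shape[-1] // 2
    return down_proj(F.silu(gate_up[..., :d]) * gate_up[..., d:])
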
@@ -285,33 +300,27 @@ class JambaMoE(nn.Module):
     """
 
     def __init__(
-        self,
-        num_experts: int,
-        top_k: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
+        self,
+        config: JambaConfig,
+        params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = num_experts
-        self.top_k = top_k
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // self.tp_size
+        self.num_total_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size // self.tp_size
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
 
-        if self.num_total_experts > 1:
-            # init expert router iff this layer has multiple experts
-            self.router = ReplicatedLinear(
-                self.hidden_size,
-                self.num_total_experts,
-                bias=False,
-                params_dtype=self.params_dtype,
-            )
+        self.router = ReplicatedLinear(self.hidden_size,
+                                       self.num_total_experts,
+                                       bias=False,
+                                       params_dtype=self.params_dtype)
 
         self.ws = nn.Parameter(
             torch.empty(
@@ -366,14 +375,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
-        if self.num_total_experts > 1:
-            router_logits, _ = self.router(hidden_states)
-        else:
-            router_logits = torch.ones(
-                [hidden_states.shape[0], 1],
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
+        router_logits, _ = self.router(hidden_states)
 
         final_hidden_states = fused_moe(
             hidden_states,
@@ -394,28 +396,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 class JambaMambaDecoderLayer(nn.Module):
-
     def __init__(
-        self,
-        config: JambaConfig,
-        actual_num_experts: int,
-        actual_num_experts_per_tok: int,
-        layer_idx: int,
+        self, config: JambaConfig, layer_idx: int, quant_config: Optional[QuantizationConfig] = None
     ) -> None:
         super().__init__()
         self.layer_idx = layer_idx
         self.config = config
         self.mamba = JambaMambaMixer(config, layer_idx)
-        self.moe = JambaMoE(
-            num_experts=actual_num_experts,
-            top_k=actual_num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-        )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.pre_moe_layernorm = RMSNorm(config.hidden_size,
-                                         eps=config.rms_norm_eps)
+
+        num_experts = config.layers_num_experts[layer_idx]
+        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
+        self.feed_forward = ffn_layer_class(config, quant_config)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
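
Both decoder layer types now read their feed-forward choice from config.layers_num_experts instead of taking expert counts as constructor arguments. A sketch, under the assumption that the official config derives that list from the same expert_layer_period / expert_layer_offset schedule the removed code computed inline:

def layers_num_experts(config):
    # One entry per layer: num_experts on expert layers, 1 on dense layers.
    return [
        config.num_experts
        if (i - config.expert_layer_offset) % config.expert_layer_period == 0
        else 1
        for i in range(config.num_hidden_layers)
    ]

# A layer then picks JambaMoE when its entry is > 1 and JambaMLP when it is exactly 1.
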
@@ -436,20 +429,15 @@ def forward(
         hidden_states = self.mamba(hidden_states, attn_metadata, conv_state,
                                    ssm_state)
         # Fully Connected
-        hidden_states, residual = self.pre_moe_layernorm(
-            hidden_states, residual)
-        hidden_states = self.moe(hidden_states)
+        hidden_states, residual = self.pre_ff_layernorm(hidden_states, residual)
+        hidden_states = self.feed_forward(hidden_states)
         return hidden_states, residual
 
 
 class JambaAttentionDecoderLayer(nn.Module):
 
     def __init__(
-        self,
-        config: JambaConfig,
-        actual_num_experts: int,
-        actual_num_experts_per_tok: int,
-        quant_config: Optional[QuantizationConfig] = None,
+        self, config: JambaConfig, layer_idx: int, quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -494,16 +482,11 @@ def __init__(
             sliding_window=self.sliding_window,
         )
 
-        self.moe = JambaMoE(
-            num_experts=actual_num_experts,
-            top_k=actual_num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-        )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.pre_moe_layernorm = RMSNorm(config.hidden_size,
-                                         eps=config.rms_norm_eps)
+        num_experts = config.layers_num_experts[layer_idx]
+        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
+        self.feed_forward = ffn_layer_class(config, quant_config)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def self_attention(
         self,
@@ -542,12 +525,14 @@ def forward(
             attn_metadata=attn_metadata,
         )
         # Fully Connected
-        hidden_states, residual = self.pre_moe_layernorm(
-            hidden_states, residual)
-        hidden_states = self.moe(hidden_states)
+        hidden_states, residual = self.pre_ff_layernorm(hidden_states, residual)
+        hidden_states = self.feed_forward(hidden_states)
         return hidden_states, residual
 
 
+ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer}
+
+
 class JambaModel(nn.Module):
 
     def __init__(
@@ -570,40 +555,12 @@ def __init__(
             org_num_embeddings=config.vocab_size,
         )
 
-        # init each model layer, decide if it's mamba/attention and
-        # has experts and pass it down
-
-        module_list = []
+        decoder_layers = []
         for i in range(config.num_hidden_layers):
-            is_attn = ((i - self.config.attn_layer_offset) %
-                       self.config.attn_layer_period == 0)
-            is_expert = ((i - self.config.expert_layer_offset) %
-                         self.config.expert_layer_period == 0)
-
-            actual_num_experts = config.num_experts if is_expert else 1
-            actual_num_experts_per_tok = config.num_experts_per_tok \
-                if is_expert else 1
-
-            if is_attn:
-                module_list.append(
-                    JambaAttentionDecoderLayer(
-                        config,
-                        actual_num_experts=actual_num_experts,
-                        actual_num_experts_per_tok=actual_num_experts_per_tok,
-                        quant_config=quant_config
-                    ))
-            else:
-                module_list.append(
-                    JambaMambaDecoderLayer(
-                        config,
-                        actual_num_experts=actual_num_experts,
-                        actual_num_experts_per_tok=actual_num_experts_per_tok,
-                        layer_idx=i,
-                    ))
-
-        self.layers = nn.ModuleList(module_list)
-        self.final_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
+            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
+            decoder_layers.append(layer_class(config, layer_idx=i, quant_config=quant_config))
+        self.layers = nn.ModuleList(decoder_layers)
+        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
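
JambaModel now dispatches layer construction through the ALL_DECODER_LAYER_TYPES registry keyed by config.layers_block_type, replacing the inline attention-schedule arithmetic. A sketch of the equivalence, assuming the official config encodes the same attn_layer_period / attn_layer_offset rule the removed code used:

def layers_block_type(config):
    # One entry per layer: "attention" on the periodic attention layers, "mamba" otherwise.
    return [
        "attention"
        if (i - config.attn_layer_offset) % config.attn_layer_period == 0
        else "mamba"
        for i in range(config.num_hidden_layers)
    ]

# ALL_DECODER_LAYER_TYPES maps each string to JambaAttentionDecoderLayer or
# JambaMambaDecoderLayer, so the construction loop stays a two-liner.
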
@@ -732,6 +689,8 @@ def load_weights( self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
 
         expert_params_mapping = [
@@ -758,6 +717,8 @@ def load_weights( self, weights: Iterable[Tuple[str, torch.Tensor]]):
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+                if 'experts' in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
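
The two new gate_up_proj entries let load_weights fuse the checkpoint's separate gate_proj / up_proj tensors into the merged parameter, while the added 'experts' guard leaves expert weights to expert_params_mapping. A self-contained sketch of that renaming step (the checkpoint key is a hypothetical example):

stacked_params_mapping = [
    # (fused param name, checkpoint weight name, shard id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def remap(name: str):
    # Mirror of the loop above: skip expert weights, otherwise rewrite the
    # per-projection checkpoint name to the fused vLLM parameter name.
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name or "experts" in name:
            continue
        return name.replace(weight_name, param_name), shard_id
    return name, None

print(remap("model.layers.0.feed_forward.gate_proj.weight"))
# -> ('model.layers.0.feed_forward.gate_up_proj.weight', 0)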

vllm/transformers_utils/configs/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,8 @@
 from vllm.transformers_utils.configs.jais import JAISConfig
 from vllm.transformers_utils.configs.mpt import MPTConfig
 
+from vllm.transformers_utils.configs.jamba import JambaConfig
+
 __all__ = [
     "ChatGLMConfig", "DbrxConfig", "MPTConfig", "RWConfig", "JAISConfig"
 ]