Add Bamba Model #10909

Merged (78 commits) on Feb 6, 2025
Changes from 1 commit (of 78 commits)
62181d5
initial pr without tp fix
fabianlim Dec 5, 2024
51bc78c
fix casting in rms norm gated
fabianlim Dec 5, 2024
81b93b4
TP fix
fabianlim Dec 5, 2024
0f93e4a
fix mamba scan invalid address
fabianlim Dec 8, 2024
742ae79
some fixes and remove unused kernels
fabianlim Dec 12, 2024
b2dc5ca
fmt + lint
fabianlim Dec 12, 2024
9ad9e20
more comments
fabianlim Dec 12, 2024
25bf381
initial fix for chunked prefill (incomplete)
fabianlim Dec 12, 2024
43ce07c
improve comments
fabianlim Dec 12, 2024
80f14b5
do not attach seq_idx to attn_metadata
fabianlim Dec 12, 2024
6b8ac49
activate initial states for chunked prefill
fabianlim Dec 12, 2024
d788db6
reuse softplus and remove triton2 remark
fabianlim Dec 13, 2024
400db27
add comment on weight loader and format
fabianlim Dec 13, 2024
bda8ea7
rename test_jamba to test_hybrid and got rid of test_bamba
fabianlim Dec 13, 2024
66078d6
Merge remote-tracking branch 'upstream/main' into bamba-pr
fabianlim Dec 16, 2024
a74de9f
update bamba to ishybrid and support pp
fabianlim Dec 16, 2024
b44caa7
lint
fabianlim Dec 16, 2024
8cf3644
add unit test for mamba ssd
fabianlim Dec 16, 2024
e375b40
fix lint
fabianlim Dec 16, 2024
dcbae7b
full chunked-prefill fix (sans unit tests)
fabianlim Dec 21, 2024
2597105
format and add cont batch unit tests (will need more cases)
fabianlim Dec 23, 2024
db5eea5
fix kernel tests and add more chunked prefill cases
fabianlim Dec 23, 2024
dfbcb16
bound adjustment
fabianlim Dec 23, 2024
7913009
bound adjustment
fabianlim Dec 26, 2024
9c5d045
lint errors
fabianlim Dec 26, 2024
6bc9dac
Add permalink correction from @tlrmchlsmth
fabianlim Jan 3, 2025
6d02e85
improved comment for segsum, add more sizes for test_mamba_chunk_scan…
fabianlim Jan 3, 2025
e5882f2
rename and comment functions, add more sizes for test_mamba_chunk_sca…
fabianlim Jan 3, 2025
6d6fa86
addressed comments on mamba_mixer2.py
fabianlim Jan 3, 2025
773dd80
replace with get_rope
fabianlim Jan 3, 2025
63f5340
rope scaling
fabianlim Jan 4, 2025
89e36d8
fixes
fabianlim Jan 6, 2025
7a4ae96
zero out ssm states
fabianlim Jan 7, 2025
a9e149c
fix tests (sans updating dev checkpoint)
fabianlim Jan 7, 2025
5c9f48d
not replacing dev model for now
fabianlim Jan 11, 2025
55647b1
update requirements
fabianlim Jan 13, 2025
2342bc0
remove extraneous comment
fabianlim Jan 14, 2025
011c141
update test
fabianlim Jan 14, 2025
503bc42
fix lint
fabianlim Jan 15, 2025
312cf1d
fix lint
fabianlim Jan 15, 2025
c1db743
fix requirements-test
fabianlim Jan 15, 2025
c956a30
Mamba2 changes from #10909
tlrmchlsmth Jan 16, 2025
17923ad
Get Mamba2 working!
tlrmchlsmth Jan 16, 2025
4183d45
Add integration test -- something is wrong!!
tlrmchlsmth Jan 17, 2025
5377644
format
tlrmchlsmth Jan 17, 2025
39f55d1
fixes
tlrmchlsmth Jan 17, 2025
dd31f19
update test registry, fixes
fabianlim Jan 16, 2025
e2e5aac
Fix for conv state shape and update placeholder_attn
tlrmchlsmth Jan 19, 2025
bc1b8af
back out placeholder_attn changes
tlrmchlsmth Jan 19, 2025
9db0dd5
make seq_idx to chunk indices more efficient
fabianlim Jan 20, 2025
cd89283
WIP debugging, restore local mamba and placeholder_attn changes
tlrmchlsmth Jan 20, 2025
9a838a3
Integration tests are now green
tlrmchlsmth Jan 20, 2025
be8318e
remove bamba-specific files
tlrmchlsmth Jan 20, 2025
f34d434
Merge branch 'main' into tms/mamba2
tlrmchlsmth Jan 27, 2025
a65e2cb
Handle grouping in Mixer2RMSNormGated
tlrmchlsmth Jan 30, 2025
0d4bb0f
debug cruft
tlrmchlsmth Jan 30, 2025
74f6088
Remove codestral integration test
tlrmchlsmth Jan 30, 2025
95583b8
Merge branch 'tms/mamba2' into bamba-pr
fabianlim Feb 1, 2025
b72389c
update mamba_cache
fabianlim Feb 1, 2025
10d75eb
remove changes to requirements
fabianlim Feb 1, 2025
5aea1e6
revert changes
fabianlim Feb 1, 2025
2ee8d07
Merge remote-tracking branch 'upstream/main' into bamba-pr
fabianlim Feb 1, 2025
043e006
fix lint
fabianlim Feb 1, 2025
7e4ce4f
fix lint
fabianlim Feb 1, 2025
8219480
more reverts
fabianlim Feb 1, 2025
2a154e1
remove unnecessary stuff
fabianlim Feb 3, 2025
b0536f7
add mixer2 gated norm TP test
fabianlim Feb 3, 2025
b2e7952
Merge remote-tracking branch 'upstream/main' into bamba-pr
fabianlim Feb 3, 2025
06c4e7f
add header
fabianlim Feb 3, 2025
851239a
fix lint
fabianlim Feb 3, 2025
6466c3c
Merge branch 'main' into bamba-pr
tlrmchlsmth Feb 3, 2025
64f6a4e
checkpoint renames
fabianlim Feb 4, 2025
266ce81
(debug) test_mamba_ssm_ssd.py
fabianlim Feb 4, 2025
965620d
[debug] make all run same shard_id
fabianlim Feb 4, 2025
4a846ab
[debug] disable test case
fabianlim Feb 4, 2025
da380b1
revert debugs and add @tlrmchlsmth fix!
fabianlim Feb 6, 2025
51d3762
Merge branch 'main' into bamba-pr
tlrmchlsmth Feb 6, 2025
eba332a
update mamba and jamba for MambaCache changes
tlrmchlsmth Feb 6, 2025
TP fix
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
fabianlim committed Dec 12, 2024
commit 81b93b40933a9423a9c9acf0cca88e35fa457875
211 changes: 183 additions & 28 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -18,36 +18,62 @@
mamba_chunk_scan_combined)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed import (divide, get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, sharded_weight_loader, LoaderFunction)


from typing import Tuple, Union, Optional
from typing import Tuple, Union, Optional, List
from vllm.model_executor.custom_op import CustomOp

# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
# also referenced https://github.com/vllm-project/vllm/pull/9292
@CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.variance_epsilon = eps
self.weight = nn.Parameter(torch.ones(hidden_size))
self.tp_size = get_tensor_model_parallel_world_size()
set_weight_attrs(self.weight,
{"weight_loader": sharded_weight_loader(0)})

def forward_native(
self,
x: torch.Tensor,
gate: torch.Tensor,
):
pass
input_dtype = x.dtype
x = x * nn.functional.silu(gate.to(torch.float32))

if self.tp_size > 1:
# Compute local sum and then reduce to obtain global sum
local_sums = x.pow(2).sum(dim=-1, keepdim=True)
global_sums = tensor_model_parallel_all_reduce(local_sums)
# Calculate the variance
count = self.tp_size * x.shape[-1]
variance = (global_sums / count)

else:
variance = x.pow(2).mean(-1, keepdim=True)

x = x * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * x.to(input_dtype)
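# Note on the TP branch above: each rank holds only hidden_size / tp_size
# channels, so the local sums of squares are all-reduced and divided by
# tp_size * x.shape[-1] (the full hidden size). Every rank thus recovers the
# same variance a single-device RMSNorm would compute, and then applies its
# own shard of `weight` to its slice of x.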

def forward_cuda(
self,
x: torch.Tensor,
gate: torch.Tensor,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:

if self.tp_size > 1:
return self.forward_native(x, gate)

from vllm import _custom_ops as ops

# cast gate to float32 before silu
# cast x and gate to float32 before silu
out = torch.empty_like(x)
y = x * nn.functional.silu(gate.to(torch.float32))
ops.rms_norm(
@@ -58,6 +84,57 @@ def forward_cuda(
)
return out

def extra_groups_for_head_shards(ngroups: int, tp_size: int):
"""Compute the extra (logical) groups to account for head shards"""

# in the case ngroups % tp_size == 0, this will be zero
if ngroups % tp_size == 0:
return 0

return tp_size - ngroups % tp_size
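# Example: with ngroups=1 and tp_size=4, this returns 4 - (1 % 4) = 3, i.e.
# the layer pads from 1 to 4 logical groups so each TP rank owns one group.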

def mamba_v2_sharded_weight_loader(
shard_spec: List[int], tp_size: int, tp_rank: int,
) -> LoaderFunction:
"""Create a weight loader for mamba v2. This ensures that the projections are
correctly sharded so that they can be split into x, B, C. It also ensures that
all the groups corresponding to a head shard are placed together with it.
"""

def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:

# - track boundary of (sharded) param, and loaded_weight, respectively
boundary, loaded_boundary = 0, 0
for full_dim, extra, ratio in shard_spec:
# - full dim is the expected size of the model
# - if extra > 0, this means there was some expansion

# - num of dims expected to be loaded
shard_size = full_dim // tp_size

# - compute where to take the loaded shard from
rank = tp_rank // ratio

# - should start from here (determined by rank)
loaded_skip = rank * shard_size # offset into loaded_weight for this rank
loaded_start_idx = loaded_boundary + loaded_skip

# - number of dims to take from loaded_weight
take = min(shard_size, full_dim - extra - loaded_skip)

# - always shard on dim 0
param.data[
boundary:boundary+take,...
] = loaded_weight[
loaded_start_idx:loaded_start_idx+take
]

# move boundaries
boundary += shard_size
loaded_boundary += (full_dim - extra)

return loader
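# Worked example with hypothetical sizes: tp_size=2 and
# shard_spec = [(1024, 0, 1), (32, 16, 8)], i.e. intermediate_size=1024 plus
# a single original group of state size 16 padded to 2 logical groups
# (ratio = num_heads // n_groups = 8). Rank 0 copies loaded_weight[0:512]
# into param[0:512] and loaded_weight[1024:1040] into param[512:528];
# rank 1 copies loaded_weight[512:1024] into param[0:512] and the SAME slice
# loaded_weight[1024:1040] into param[512:528], so the whole group is
# replicated alongside each head shard.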

# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
@CustomOp.register("mamba_mixer2")
class MambaMixer2(CustomOp):
@@ -76,7 +153,6 @@ def __init__(self,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
time_step_rank: int,
use_conv_bias: bool,
use_bias: bool,
use_rms_norm: bool,
@@ -87,7 +163,22 @@
activation="silu",
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.time_step_rank = time_step_rank

# For TP, the sharding plan is as follows:
# - for the conv modules, since
# conv_dim = intermediate_size + 2 * n_groups * ssm_state_size,
# we shard intermediate_size and n_groups
# - since intermediate_size = n_heads * head_dim, sharding on
# intermediate_size is achieved by sharding on n_heads.
# - so if world_size divides groups, then sharding
# (n_groups / world_size, n_heads / world_size)
# also maintains the invariant n_heads % n_groups == 0
# - HOWEVER, if world_size DOES NOT divide groups, then we need to allocate
# extra space in the shard, such that the WHOLE GROUP must be placed
# together with the HEAD SHARD.
self.tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()

self.ssm_state_size = ssm_state_size
self.use_rms_norm = use_rms_norm
Review comment (Collaborator):
Just noticed that this doesn't seem to be used at all. Should it be removed, or is it to be used somewhere?

Reply (Contributor Author):
oh yea.. let me check and I can remove
self.activation = activation
@@ -96,8 +187,17 @@ def __init__(self,
self.intermediate_size = intermediate_size
self.head_dim = head_dim
self.num_heads = num_heads

self.n_groups = n_groups
self.conv_dim = intermediate_size + 2 * n_groups * ssm_state_size
if n_groups % self.tp_size != 0:
# - for TP we shard conv_dim by sharding on n_groups,
# - but if tp_size does not divide n_groups, we need to
# extend with some extra (padded) groups
self.n_groups = n_groups + extra_groups_for_head_shards(n_groups, self.tp_size)

self.conv_dim = (
intermediate_size + 2 * self.n_groups * ssm_state_size
)
self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
output_size=self.conv_dim,
@@ -116,22 +216,66 @@
bias=use_bias,
quant_config=quant_config)

# unlike mamba_mixer.py (v1), we do not TP the A matrix as it is
# already quite small.
# - same for dt_bias and D

def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
param.data.copy_(-torch.exp(loaded_weight.float()))
# - because in_proj is a concatenation of 3 weights, we
# need to interleave them before sharding
# - use the custom weight loader mamba_v2_sharded_weight_loader
# for conv1d.bias, conv1d.weight and in_proj.weight
# - need to set these settings, to assign the groups to the head shards
group_shard_settings = (
self.n_groups * self.ssm_state_size, # expected model size
(self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
self.num_heads // n_groups, # ratio for mapping back to original group
)
intemediate_settings = (intermediate_size, 0, 1)
head_setings = (self.num_heads, 0, 1)

delattr(self.conv1d.bias, "weight_loader")
set_weight_attrs(self.conv1d.bias, {
"weight_loader": mamba_v2_sharded_weight_loader(
[
intemediate_settings, group_shard_settings, group_shard_settings,
],
self.tp_size, tp_rank,
)
})

delattr(self.conv1d.weight, "weight_loader")
set_weight_attrs(self.conv1d.weight, {
"weight_loader": mamba_v2_sharded_weight_loader(
[
intemediate_settings, group_shard_settings, group_shard_settings,
],
self.tp_size, tp_rank
)
})

delattr(self.in_proj.weight, "weight_loader")
set_weight_attrs(self.in_proj.weight, {
"weight_loader": mamba_v2_sharded_weight_loader(
[
intemediate_settings, # for gate
intemediate_settings, group_shard_settings, group_shard_settings,
head_setings, # for dt
],
self.tp_size, tp_rank
)
})

# - these are TPed by heads to reduce the size of the
# temporal shape
self.A = nn.Parameter(
torch.empty(
num_heads,
dtype=torch.float32,
divide(num_heads, self.tp_size), dtype=torch.float32,
))
set_weight_attrs(self.A, {"weight_loader": A_weight_loader})
self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))

self.dt_bias = nn.Parameter(torch.ones(num_heads))
self.D = nn.Parameter(torch.ones(num_heads))
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
a_weight_loader = composed_weight_loader(
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
set_weight_attrs(self.dt_bias,
{"weight_loader": sharded_weight_loader(0)})

self.out_proj = RowParallelLinear(
intermediate_size,
@@ -141,7 +285,7 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
quant_config=quant_config)

self.norm = Mixer2RMSNormGated(
intermediate_size, eps=rms_norm_eps
intermediate_size // self.tp_size, eps=rms_norm_eps
)

def forward_native(self, hidden_states: torch.Tensor,
@@ -171,7 +315,11 @@ def forward_cuda(self, hidden_states: torch.Tensor,
projected_states, _ = self.in_proj(hidden_states)
gate, hidden_states_B_C, dt = torch.split(
projected_states,
[self.intermediate_size, self.conv_dim, self.num_heads],
[
self.intermediate_size // self.tp_size,
self.conv_dim // self.tp_size,
self.num_heads // self.tp_size,
],
dim=-1,
)

@@ -212,7 +360,11 @@ def forward_cuda(self, hidden_states: torch.Tensor,
# - get hidden_states, B and C after depthwise convolution.
hidden_states, B, C = torch.split(
hidden_states_B_C,
[self.intermediate_size, groups_time_state_size, groups_time_state_size],
[
self.intermediate_size // self.tp_size,
groups_time_state_size // self.tp_size,
groups_time_state_size // self.tp_size,
],
dim=-1,
)

@@ -233,11 +385,11 @@ def forward_cuda(self, hidden_states: torch.Tensor,
# ]

scan_output, varlen_state = mamba_chunk_scan_combined(
hidden_states.view(1, seq_len, -1, self.head_dim),
hidden_states.view(1, seq_len, self.num_heads // self.tp_size, self.head_dim),
dt.unsqueeze(0),
self.A,
B.view(1, seq_len, self.n_groups, -1),
C.view(1, seq_len, self.n_groups, -1),
B.view(1, seq_len, self.n_groups // self.tp_size, -1),
C.view(1, seq_len, self.n_groups // self.tp_size, -1),
chunk_size=self.chunk_size,
D=self.D,
z=None,
@@ -261,13 +413,14 @@ def forward_cuda(self, hidden_states: torch.Tensor,
else:

# NOTE: can be optimized?
n_groups = self.n_groups // self.tp_size
A = self.A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D = self.D[:, None, ...].expand(-1, self.head_dim)
B = B.view(-1, self.n_groups, B.shape[1] // self.n_groups)
C = C.view(-1, self.n_groups, C.shape[1] // self.n_groups)
hidden_states_reshaped = hidden_states.view(-1, self.num_heads, self.head_dim)
B = B.view(-1, n_groups, B.shape[1] // n_groups)
C = C.view(-1, n_groups, C.shape[1] // n_groups)
hidden_states_reshaped = hidden_states.view(-1, self.num_heads // self.tp_size, self.head_dim)

# - the hidden is reshaped into number of current batches
# - in this case there is no more prefil, so the batches gen
@@ -290,7 +443,9 @@ def forward_cuda(self, hidden_states: torch.Tensor,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor,
)
hidden_states = hidden_states.view(-1, self.num_heads * self.head_dim)
hidden_states = hidden_states.view(
-1, (self.num_heads // self.tp_size) * self.head_dim
)

# # 4. gated MLP
hidden_states = self.norm(hidden_states, gate)
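For reference, here is a minimal standalone sketch of the size arithmetic the changes above implement, assuming hypothetical Bamba-like dimensions (hidden_size=4096, expand=2, 128 heads of dim 64, n_groups=1, ssm_state_size=128, tp_size=4). It mirrors extra_groups_for_head_shards and the per-rank torch.split of the in_proj output; the concrete numbers are illustrative only, not vLLM code.

def extra_groups_for_head_shards(ngroups: int, tp_size: int) -> int:
    # zero when tp_size already divides ngroups
    if ngroups % tp_size == 0:
        return 0
    return tp_size - ngroups % tp_size

# hypothetical Bamba-like sizes
hidden_size, expand = 4096, 2
num_heads, head_dim = 128, 64
n_groups, ssm_state_size = 1, 128
tp_size = 4

intermediate_size = expand * hidden_size                                    # 8192 == num_heads * head_dim
padded_groups = n_groups + extra_groups_for_head_shards(n_groups, tp_size)  # 4
conv_dim = intermediate_size + 2 * padded_groups * ssm_state_size           # 9216

# per-rank sizes used by the torch.split of the in_proj output
gate_per_rank = intermediate_size // tp_size                                # 2048
conv_per_rank = conv_dim // tp_size                                         # 2304
heads_per_rank = num_heads // tp_size                                       # 32
assert conv_per_rank == gate_per_rank + 2 * (padded_groups // tp_size) * ssm_state_size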
20 changes: 14 additions & 6 deletions vllm/model_executor/models/bamba.py
@@ -9,14 +9,15 @@
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -83,7 +84,6 @@ def __init__(self,
conv_kernel_size = config.mamba_d_conv,
intermediate_size = config.mamba_expand *\
config.hidden_size,
time_step_rank = config.mamba_dt_rank,
use_conv_bias = config.mamba_conv_bias,
use_bias = config.mamba_proj_bias,
use_rms_norm=True,
@@ -459,20 +459,28 @@ def _get_mamba_cache_shape(

intermediate_size = self.config.mamba_expand * hidden_size

# if n_groups is not divisible by world_size, need to extend the shards to ensure
# all groups needed by a head are sharded along with it
n_groups = (
self.config.mamba_n_groups +
extra_groups_for_head_shards(self.config.mamba_n_groups, world_size)
)

# - heads and n_groups are TP-ed
conv_dim = (
intermediate_size +
2 * self.config.mamba_n_groups * self.config.mamba_d_state
2 * n_groups * self.config.mamba_d_state
)
conv_state_shape = (
conv_dim // world_size,
divide(conv_dim, world_size),
self.config.mamba_d_conv - 1,
)

# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
temporal_state_shape = (
self.config.mamba_n_heads,
divide(self.config.mamba_n_heads, world_size),
self.config.mamba_d_head,
self.config.mamba_d_state,
)
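A companion sketch of the cache shapes computed in _get_mamba_cache_shape, under the same hypothetical Bamba-like dimensions with world_size=4 and an assumed conv kernel size of 4: the conv state is sharded along the padded conv_dim, while the temporal SSM state is sharded across heads.

# hypothetical Bamba-like config values
world_size, mamba_d_conv = 4, 4
hidden_size, mamba_expand = 4096, 2
mamba_n_heads, mamba_d_head = 128, 64
mamba_n_groups, mamba_d_state = 1, 128

intermediate_size = mamba_expand * hidden_size                       # 8192
# pad groups so every head shard carries whole groups (here: 1 -> 4)
n_groups = mamba_n_groups + (
    (world_size - mamba_n_groups % world_size) % world_size)
conv_dim = intermediate_size + 2 * n_groups * mamba_d_state          # 9216

conv_state_shape = (conv_dim // world_size, mamba_d_conv - 1)        # (2304, 3)
temporal_state_shape = (
    mamba_n_heads // world_size, mamba_d_head, mamba_d_state)        # (32, 64, 128)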