[Add Mamba] Adds support for the Mamba models #28094

Merged on Mar 5, 2024 (123 commits)
Changes from 1 commit

Commits (123)
81c642f
initial-commit
ArthurZucker Dec 16, 2023
c50602b
Merge branch 'main' of github.com:huggingface/transformers into add-m…
ArthurZucker Jan 31, 2024
00d3a6c
start cleaning
ArthurZucker Jan 31, 2024
921bb24
small nits
ArthurZucker Feb 1, 2024
b3f216d
small nits
ArthurZucker Feb 3, 2024
7235b57
current updates
ArthurZucker Feb 3, 2024
7a407a7
add kernels
ArthurZucker Feb 5, 2024
9f2a982
small refactoring little step
ArthurZucker Feb 5, 2024
04c991a
add comments
ArthurZucker Feb 5, 2024
aa7e8d2
styling
ArthurZucker Feb 5, 2024
26748c4
nit
ArthurZucker Feb 5, 2024
75e376a
nits
ArthurZucker Feb 14, 2024
1c104b5
Style
ArthurZucker Feb 14, 2024
0e90dae
Merge
ArthurZucker Feb 14, 2024
a804466
Small changes
ArthurZucker Feb 14, 2024
6b87ad2
Push dummy mambda simple slow
ArthurZucker Feb 14, 2024
a7ec8d6
nit
ArthurZucker Feb 14, 2024
5046451
Use original names
ArthurZucker Feb 14, 2024
b5831e3
Use original names and remove norm
ArthurZucker Feb 15, 2024
e9a80ad
Updates for inference params
ArthurZucker Feb 15, 2024
ee4a7ef
Style nd updates
ArthurZucker Feb 15, 2024
d8c195f
nits
ArthurZucker Feb 15, 2024
e64fedc
Match logits
ArthurZucker Feb 16, 2024
aee558f
Add a test
ArthurZucker Feb 16, 2024
eae5f45
Add expected generated text
ArthurZucker Feb 16, 2024
1f8e8d0
nits doc, imports and styling
ArthurZucker Feb 16, 2024
3cc06e5
style
ArthurZucker Feb 16, 2024
5a5324c
oups
ArthurZucker Feb 16, 2024
325b66b
Merge branch 'main' of github.com:huggingface/transformers into add-m…
ArthurZucker Feb 16, 2024
81303f4
dont install kernels, invite users to install the required kernels
ArthurZucker Feb 19, 2024
1a10310
let use use the original packages
ArthurZucker Feb 19, 2024
89fb490
styling
ArthurZucker Feb 19, 2024
6cfe216
nits
ArthurZucker Feb 19, 2024
1ecbd22
fix some copieds
ArthurZucker Feb 19, 2024
b937122
update doc
ArthurZucker Feb 19, 2024
9752dd0
fix-copies
ArthurZucker Feb 19, 2024
a7881a3
styling done
ArthurZucker Feb 19, 2024
f445b0d
nits
ArthurZucker Feb 19, 2024
64ec8dd
fix import check
ArthurZucker Feb 19, 2024
e6e3ba8
run but wrong cuda ress
ArthurZucker Feb 19, 2024
ed4eb4c
mamba CUDA works :)
ArthurZucker Feb 19, 2024
4c8fc48
fix the fast path
ArthurZucker Feb 19, 2024
69e103f
config naming nits
ArthurZucker Feb 19, 2024
ba21ff2
conversion script is not required at this stage
ArthurZucker Feb 19, 2024
fe53728
finish fixing the fast path: generation make sense now!
ArthurZucker Feb 19, 2024
9411169
nit
ArthurZucker Feb 19, 2024
c2c7709
Let's start working on the CIs
ArthurZucker Feb 19, 2024
1e73ca9
style
ArthurZucker Feb 19, 2024
834f46f
git push Merge branch 'main' of github.com:huggingface/transformers i…
ArthurZucker Feb 19, 2024
a1a94f3
Merge branch 'main' of github.com:huggingface/transformers into add-m…
ArthurZucker Feb 20, 2024
2213222
better style
ArthurZucker Feb 20, 2024
2a02006
more nits
ArthurZucker Feb 20, 2024
8b0412f
test nit
ArthurZucker Feb 20, 2024
fbd6a2c
quick fix for now
ArthurZucker Feb 20, 2024
823f11a
nits
ArthurZucker Feb 20, 2024
88896a9
nit
ArthurZucker Feb 20, 2024
7f72ee8
nit
ArthurZucker Feb 21, 2024
0555247
Merge branch 'main' of github.com:huggingface/transformers into add-m…
ArthurZucker Feb 29, 2024
0072a6c
nit
ArthurZucker Feb 29, 2024
7f6c56f
nits
ArthurZucker Feb 29, 2024
f67c353
update test rest
ArthurZucker Feb 29, 2024
2ab5a86
fixup
ArthurZucker Feb 29, 2024
8920be3
update test
ArthurZucker Feb 29, 2024
87d0664
nit
ArthurZucker Feb 29, 2024
8b00d76
some fixes
ArthurZucker Feb 29, 2024
ca9835c
nits
ArthurZucker Feb 29, 2024
796ef3e
update test values
ArthurZucker Feb 29, 2024
170664a
fix styling
ArthurZucker Feb 29, 2024
92493a0
nit
ArthurZucker Feb 29, 2024
854ebad
support peft
ArthurZucker Feb 29, 2024
3bbd1b1
Merge branch 'main' of github.com:huggingface/transformers into add-m…
ArthurZucker Feb 29, 2024
aa0e6bb
integrations tests require torchg
ArthurZucker Feb 29, 2024
3c1537e
also add slow markers
ArthurZucker Feb 29, 2024
d06421a
styling
ArthurZucker Feb 29, 2024
5fb8062
chose forward wisely
ArthurZucker Feb 29, 2024
edb4e91
nits
ArthurZucker Feb 29, 2024
eb1fb64
update tests
ArthurZucker Feb 29, 2024
de4fe46
fix gradient checkpointing
ArthurZucker Feb 29, 2024
54ffaa3
fixup
ArthurZucker Feb 29, 2024
977d34f
nit
ArthurZucker Feb 29, 2024
0928453
fix doc
ArthurZucker Feb 29, 2024
2c90536
check copies
ArthurZucker Feb 29, 2024
4ba9c79
fix the docstring
ArthurZucker Feb 29, 2024
3651dba
fix some more tests
ArthurZucker Feb 29, 2024
426e6f3
style
ArthurZucker Feb 29, 2024
951b1aa
fix beam search
ArthurZucker Mar 1, 2024
4101369
add init schene
ArthurZucker Mar 1, 2024
65db96b
update
ArthurZucker Mar 1, 2024
0f3dfc7
nit
ArthurZucker Mar 1, 2024
f8bd0aa
fix
ArthurZucker Mar 1, 2024
b2bd0c7
fixup the doc
ArthurZucker Mar 1, 2024
cf58529
fix the doc
ArthurZucker Mar 1, 2024
e9c3447
fixup
ArthurZucker Mar 1, 2024
1282a75
tentative update but slow is no longer good
ArthurZucker Mar 1, 2024
fa561b2
nit
ArthurZucker Mar 1, 2024
91b8106
should we always use float32?
ArthurZucker Mar 1, 2024
e8142ca
nits
ArthurZucker Mar 1, 2024
623b636
revert wrong changes
ArthurZucker Mar 1, 2024
566c799
res in float32
ArthurZucker Mar 1, 2024
5d637d9
cleanup
ArthurZucker Mar 2, 2024
648a292
skip fmt for now
ArthurZucker Mar 2, 2024
e306e89
update generation values
ArthurZucker Mar 2, 2024
057d7a3
update test values running original model
ArthurZucker Mar 2, 2024
72f8936
fixup
ArthurZucker Mar 2, 2024
f415081
update tests + rename inference_params to cache_params + make sure tr…
ArthurZucker Mar 4, 2024
6bb659a
small nits
ArthurZucker Mar 4, 2024
178fe76
more nits
ArthurZucker Mar 4, 2024
3a46724
fix final CIs
ArthurZucker Mar 4, 2024
13204e0
style
ArthurZucker Mar 4, 2024
1608a90
nit doc
ArthurZucker Mar 4, 2024
99119ba
I hope final doc nits
ArthurZucker Mar 4, 2024
d6fb1ef
nit
ArthurZucker Mar 4, 2024
844530f
🫠
ArthurZucker Mar 4, 2024
52be018
final touch!
ArthurZucker Mar 4, 2024
d03de1c
fix torch import
ArthurZucker Mar 4, 2024
c0672a8
Apply suggestions from code review
ArthurZucker Mar 5, 2024
dfc1212
Apply suggestions from code review
ArthurZucker Mar 5, 2024
acd4ccf
fix fix and fix
ArthurZucker Mar 5, 2024
2ddd9aa
fix base model prefix!
ArthurZucker Mar 5, 2024
0c5d7ed
nit
ArthurZucker Mar 5, 2024
28e5ef0
Update src/transformers/models/mamba/__init__.py
ArthurZucker Mar 5, 2024
f963e38
Update docs/source/en/model_doc/mamba.md
ArthurZucker Mar 5, 2024
095dabd
nit
ArthurZucker Mar 5, 2024
Push dummy mambda simple slow
ArthurZucker committed Feb 14, 2024
commit 6b87ad2c106325b178512035e4ce4680fe2e549c
161 changes: 75 additions & 86 deletions src/transformers/models/mamba/modeling_mamba.py
@@ -258,7 +258,7 @@ def __init__(self, config, layer_idx):
# self.use_fast_path = config.use_fast_path
self.layer_idx = layer_idx

self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=config.use_bias)


self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
@@ -272,14 +272,16 @@ def __init__(self, config, layer_idx):
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]

# projection of the input hidden states
self.input_projection = nn.Linear(self.d_model, self.d_inner * 2, bias=config.use_bias)
# selective projection used to make dt, B and C input dependant
self.x_proj = nn.Linear(self.d_inner, self.time_step_rank + self.d_state * 2, bias=False)
self.time_step_proj = nn.Linear(self.time_step_rank, self.d_inner, bias=True)
self.discrete_projection = nn.Linear(self.d_inner, self.time_step_rank + self.d_state * 2, bias=False)
# time step projection (discretization)
self.time_step_projection = nn.Linear(self.time_step_rank, self.d_inner, bias=True)
# S4D real initialization. These are not discretized!
# THe core is to load them, compute the discrete states, then write the updates state.
# Keeps the memory bounded
what_is_this = torch.arange(1, self.d_state + 1, dtype=torch.float32)
A = what_is_this.repeat(self.d_inner).contiguous()
A = torch.arange(1, self.d_state + 1, dtype=torch.float32)[None,:].expand(self.d_inner, -1).contiguous()
A_log = torch.log(A) # Keep A_log in fp32
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
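
A note for readers of this hunk: the renamed projections make the data flow explicit (input projection, discretization projection, time-step projection), and A is now built with the S4D real initialization, where every row of A is simply 1..d_state and the parameter is stored as log(A). A minimal standalone sketch of that initialization, using made-up sizes rather than the config-driven ones:

import torch
import torch.nn as nn

d_inner, d_state = 8, 4  # toy sizes for illustration only
A = torch.arange(1, d_state + 1, dtype=torch.float32)[None, :].expand(d_inner, -1).contiguous()
A_log = nn.Parameter(torch.log(A))  # stored in log-space and kept in fp32
A_log._no_weight_decay = True       # flag used to exclude the parameter from weight decay
# At run time the mixer recovers A as -torch.exp(A_log), i.e. a negative real matrix.
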
@@ -295,8 +297,8 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None):
Returns: same shape as hidden_states
"""
_, seqlen, _ = hidden_states.shape
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
return None
# conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]

projected_states = self.in_proj(hidden_states).transpose(1,2)
hidden_states, z = projected_states.chunk(2, dim=1)

@@ -355,86 +357,74 @@ def update_ssm_state(self, ssm_state):

class MambaSlowMixer(MambaMixer):

def forward(self, hidden_states, inference_params=MambaCache()):
def forward(self, hidden_states, inference_params=None):
"""

Compute ∆ A B C D, the state space parameters.
A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
and is why Mamba is called **selective** state spaces)

Args:
hidden_states:
inference_params:

Returns:

"""
batch_size, seq_len, _ = hidden_states.shape

# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states).transpose(1,2)
projected_states = self.input_projection(hidden_states).transpose(1,2)
hidden_states, gate = projected_states.chunk(2, dim=1)

# 2. Convolution sequence transformation
if inference_params is not None:
conv_state = inference_params.update_conv_states(hidden_states)

# TODO replace with simple conv call
# conv_state.copy_(self.conv1d(hidden_states)[..., :seq_len])

hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
# hidden_states = torch.sum(conv_state * torch.rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
# if self.conv1d.bias is not None:
# hidden_states = hidden_states + self.conv1d.bias
hidden_states = self.act(hidden_states).to(dtype=hidden_states.dtype)

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
x_dbl = self.x_proj(torch.rearrange(hidden_states, "b d l -> (b l) d")) # (bl d)
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj.weight @ dt.t()
dt = dt.transpose(0,1)
x_dbl = self.discrete_projection(hidden_states.transpose(1,2))
time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.d_state, self.d_state], dim=-1)
discrete_time_step = self.time_step_projection(time_step)

# discrete_time_step = discrete_time_step.transpose(0,1)
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)

B = B.permute(0,2,1).contiguous()
C = C.permute(0,2,1).contiguous()

# 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N)
dt = nn.functional.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))

# TODO replace einsums
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
dB = torch.einsum("bd,bn->bdn", dt, B)

discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1,2)
# [batch_size, d, l, 1] X [1, d, 1, n] -> [batch_size, d, l, n]
dA = torch.exp(discrete_time_step[:, :, :, None] * A[None, :, None, :])
# [batch_size, d, l, 1] [b, d, l, 1] -> [batch_size, d, l, 1] X [batch_size, 1, l, n] -> [batch_size, d, l, n]
deltaB_u = (discrete_time_step[:, :, :, None] * hidden_states[:, :, :, None]) * B[:, None, :, :]

ssm_state = torch.zeros((batch_size, self.d_inner, self.d_state), device=A.device)
# ssm_state = inference_params.ssm_state
# 3.c perform the recurrence y ← SSM(A, B, C)(x)

ys = []
for i in range(seq_len):
self.ssm_state.copy_(self.ssm_state * dA + torch.rearrange(hidden_states, "b d -> b d 1") * dB)
y = torch.einsum(self.ssm_state, C[:, i, :], 'b d_in n, b n -> b d_in')
ys.append(y)
y = torch.stack(ys, dim=1) # shape (b, l, d_in)
ssm_state.copy_(ssm_state * dA[:, :, i, :] + deltaB_u[:, :, i, :])
# [b, d, n] X [b, n] -> [b, d]
y = torch.matmul(ssm_state, C[:,i,:].unsqueeze(-1))
ys.append(y[:,:,0])
y = torch.stack(ys, dim=1) # shape (b, l, d)

y = y + self.D.to(hidden_states.dtype) * hidden_states
y = y * self.act(gate) # (B D)
y = y + (hidden_states * self.D.to(hidden_states.dtype)[None,:,None]).transpose(1,2)
y = y * self.act(gate).transpose(1,2) # (B D)

# 4. Final linear projection
attn_outputs = self.out_proj(y)
return attn_outputs, conv_state, y

return attn_outputs, None, ssm_state, y
return attn_outputs, conv_state, ssm_state, y

_xz = self.in_proj(hidden_states)
_x, _z = _xz.chunk(2, dim=-1) # (B D)
conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1)
conv_out = causal_conv1d_fn(
x=conv_state_new,
weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
bias=self.conv1d.bias,
activation=self.activation
)
conv_state = conv_state_new[:, :, 1:]
bsz, seqlen, dim = hidden_states.shape
output_tensor = torch.zeros(
(bsz, seqlen, dim),
device=hidden_states.device,
dtype=hidden_states.dtype
)
for i in range(0, bsz):
x = conv_out[i:i+1,:,-1]
z = _z[i:i+1, -1, :]
x_db = self.x_proj(x)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = F.linear(dt, self.dt_proj.weight)
y = selective_state_update(
ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
)
out = self.out_proj(y)
output_tensor[i] = out
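
To make the slow path easier to follow, here is a self-contained sketch of what the loop over seq_len earlier in MambaSlowMixer.forward computes: a zero-order-hold discretization of A and B, followed by the sequential state update h_t = dA * h_{t-1} + dB * x_t and readout y_t = C h_t. Shapes and names below are illustrative assumptions, not the PR's exact API:

import torch

batch, d_inner, d_state, seq_len = 2, 8, 4, 6            # toy sizes
x     = torch.randn(batch, d_inner, seq_len)             # post-convolution activations
delta = torch.nn.functional.softplus(torch.randn(batch, d_inner, seq_len))
A     = -torch.exp(torch.randn(d_inner, d_state))        # negative real parameterization
B     = torch.randn(batch, seq_len, d_state)
C     = torch.randn(batch, seq_len, d_state)
D     = torch.randn(d_inner)

# Zero-order hold: discretize A and B per time step.
dA  = torch.exp(delta[:, :, :, None] * A[None, :, None, :])        # (b, d, l, n)
dBx = delta[:, :, :, None] * x[:, :, :, None] * B[:, None, :, :]   # (b, d, l, n)

ssm_state = torch.zeros(batch, d_inner, d_state)
ys = []
for t in range(seq_len):
    ssm_state = ssm_state * dA[:, :, t, :] + dBx[:, :, t, :]        # h_t = dA h_{t-1} + dB x_t
    ys.append(torch.einsum("bdn,bn->bd", ssm_state, C[:, t, :]))    # y_t = C h_t
y = torch.stack(ys, dim=-1) + D[None, :, None] * x                  # skip connection through D
print(y.shape)  # torch.Size([2, 8, 6])
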


class MambaBlock(nn.Module):
@@ -444,17 +434,17 @@ def __init__(self, config, layer_idx):
self.layer_idx = layer_idx
# self.residual_in_fp32 = config.residual_in_fp32
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mixer = MambaMixer(config, layer_idx = layer_idx)
self.mixer = MambaSlowMixer(config, layer_idx = layer_idx)

def forward(self, hidden_states, residual=None, inference_params=None):
residual = (hidden_states + residual) if residual is not None else hidden_states
def forward(self, hidden_states, inference_params=None):
residual = hidden_states
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
# if self.residual_in_fp32:
# residual = residual.to(torch.float32)

hidden_states = self.mixer(hidden_states, inference_params=inference_params)
outputs = (hidden_states, residual)
return outputs
hidden_states, con_states, ssm_state, y = self.mixer(hidden_states, inference_params=inference_params)
hidden_states = residual + hidden_states
return hidden_states
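
The rewritten MambaBlock.forward above is a standard pre-norm residual block: normalize, run the mixer, then add the input back. A hedged sketch of that pattern with a stand-in mixer (MixerStub is illustrative, not a class from the PR):

import torch
import torch.nn as nn

class MixerStub(nn.Module):
    def forward(self, hidden_states):
        return hidden_states  # identity stand-in for the Mamba mixer

class ResidualBlockSketch(nn.Module):
    def __init__(self, hidden_size=16, eps=1e-5):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size, eps=eps)
        self.mixer = MixerStub()

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
        hidden_states = self.mixer(hidden_states)
        return residual + hidden_states  # residual added after the mixer, as in the diff

out = ResidualBlockSketch()(torch.randn(2, 3, 16))
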


class MambaPreTrainedModel(PreTrainedModel):
@@ -605,7 +595,7 @@ def __init__(self, config):

self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
self.norm_f = nn.LayerNorm(config.hidden_size) # ir use ALL_LAYER_NORM[config.hidden_states]
self.norm_f = nn.LayerNorm(config.hidden_size)

self.layers_are_rescaled = False
self.gradient_checkpointing = False
@@ -628,7 +618,7 @@ def set_input_embeddings(self, new_embeddings):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
state: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
inference_params: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -642,14 +632,14 @@ def forward(
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict


if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is None and inputs_embeds is None:
raise ValueError("You have to specify either input_ids or inputs_embeds")

if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)

# TODO better to call _set_cache
if use_cache and inference_params is None:
shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers)
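
The collapsed lines that follow allocate the per-layer cache when use_cache is set and no inference_params object was passed in. A rough sketch of that idea, using assumed shapes and a plain dict in place of the PR's cache class:

import torch

def init_cache_sketch(batch_size, intermediate_size, conv_kernel, ssm_state_size, num_layers,
                      device=torch.device("cpu"), dtype=torch.float32):
    # One convolution state and one SSM state per layer; shapes are assumptions for illustration.
    conv_states = [torch.zeros(batch_size, intermediate_size, conv_kernel, device=device, dtype=dtype)
                   for _ in range(num_layers)]
    ssm_states = [torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
                  for _ in range(num_layers)]
    return {"conv_states": conv_states, "ssm_states": ssm_states}

cache = init_cache_sketch(batch_size=1, intermediate_size=8, conv_kernel=4, ssm_state_size=4, num_layers=2)
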
@@ -665,37 +655,36 @@ def forward(

hidden_states = inputs_embeds

all_self_attentions = () if output_attentions else None

all_last_states = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for idx, layer in enumerate(self.layers):
ssm_state = None
if self.gradient_checkpointing and self.training:
hidden_states, cache, partial_states = self._gradient_checkpointing_func(
layer.__call__, hidden_states, state, use_cache, output_attentions
)
hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states, inference_params)
else:
hidden_states, cache, partial_states = layer(
hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
)
hidden_states = layer(hidden_states, inference_params=inference_params)
# inference_params.conv_state_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if output_attentions:
all_self_attentions = all_self_attentions + (partial_states,)
all_self_attentions = all_last_states + (ssm_state,)

hidden_states = self.ln_out(hidden_states)
hidden_states = self.norm_f(hidden_states)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(hidden_states for hidden_states in [hidden_states, state, all_hidden_states, all_self_attentions] if hidden_states is not None)
return tuple(hidden_states for hidden_states in [hidden_states, inference_params, all_hidden_states, all_last_states] if hidden_states is not None)

return MambaOutput(
last_hidden_state=hidden_states,
state=cache,
inference_params=inference_params,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
attentions=all_last_states,
)
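
For context on the layer loop above: during training with gradient checkpointing enabled, each block is routed through torch's checkpoint utility instead of being called directly, trading recompute for activation memory. A simplified sketch of that control flow (not the PR's exact call signature):

import torch
from torch.utils.checkpoint import checkpoint

def run_layers_sketch(layers, hidden_states, training=False, output_hidden_states=False):
    all_hidden_states = () if output_hidden_states else None
    for layer in layers:
        if training and torch.is_grad_enabled():
            # Recompute activations in the backward pass instead of storing them.
            hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
        else:
            hidden_states = layer(hidden_states)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
    return hidden_states, all_hidden_states
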


@@ -730,9 +719,9 @@ def get_input_embeddings(self):
def set_input_embeddings(self, new_embeddings):
return self.backbone.set_input_embeddings(new_embeddings)

def prepare_inputs_for_generation(self, input_ids, inference_params=None, inputs_embeds=None,attention_mask=None,**kwargs):
def prepare_inputs_for_generation(self, input_ids, inference_params=None, inputs_embeds=None, attention_mask=None, **kwargs):
# only last token for inputs_ids if the state is passed along.
if state is not None:
if inference_params is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
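
The helper above implements the usual cached-generation trick: once state has been cached, only the newest token id is fed to the model, and inputs_embeds are only honoured on the first step. A minimal sketch, reusing the argument names visible in this hunk:

import torch

def prepare_inputs_sketch(input_ids, inference_params=None, inputs_embeds=None, **kwargs):
    if inference_params is not None:
        input_ids = input_ids[:, -1].unsqueeze(-1)          # keep only the newest token
    if inputs_embeds is not None and inference_params is None:
        model_inputs = {"inputs_embeds": inputs_embeds}     # embeddings only on the first step
    else:
        model_inputs = {"input_ids": input_ids}
    model_inputs["inference_params"] = inference_params
    return model_inputs

prepare_inputs_sketch(torch.tensor([[1, 2, 3]]), inference_params=object())
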
@@ -768,7 +757,7 @@ def forward(
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

mamba_outputs = self.mamba(
mamba_outputs = self.backbone(
input_ids,
inference_params=inference_params,
inputs_embeds=inputs_embeds,
Expand All @@ -778,7 +767,7 @@ def forward(
)
hidden_states = mamba_outputs[0]

logits = self.head(hidden_states)
logits = self.lm_head(hidden_states)

loss = None
if labels is not None:
@@ -798,7 +787,7 @@ def forward(
return MambaCausalLMOutput(
loss=loss,
logits=logits,
cache=mamba_outputs.cache,
inference_params=mamba_outputs.inference_params,
hidden_states=mamba_outputs.hidden_states,
attentions=mamba_outputs.attentions,
)
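
The body of the loss computation (after the `if labels is not None:` above) is collapsed between the two hunks. The standard causal-LM pattern shifts the labels by one position so each logit predicts the next token; a hedged sketch of that common pattern follows (an assumption about the hidden lines, not a quote of them):

import torch
import torch.nn as nn

def causal_lm_loss_sketch(logits, labels):
    shift_logits = logits[..., :-1, :].contiguous()   # position t predicts token t+1
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = nn.CrossEntropyLoss()
    return loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

loss = causal_lm_loss_sketch(torch.randn(2, 5, 10), torch.randint(0, 10, (2, 5)))
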