
[Add Mamba] Adds support for the Mamba models #28094

Merged: 123 commits, merged on Mar 5, 2024

Changes from 1 commit (of 123)
81c642f
initial-commit
ArthurZucker Dec 16, 2023
c50602b
Merge branch 'main' of github.com:huggingface/transformers into add-m…
ArthurZucker Jan 31, 2024
00d3a6c
start cleaning
ArthurZucker Jan 31, 2024
921bb24
small nits
ArthurZucker Feb 1, 2024
b3f216d
small nits
ArthurZucker Feb 3, 2024
7235b57
current updates
ArthurZucker Feb 3, 2024
7a407a7
add kernels
ArthurZucker Feb 5, 2024
9f2a982
small refactoring little step
ArthurZucker Feb 5, 2024
04c991a
add comments
ArthurZucker Feb 5, 2024
aa7e8d2
styling
ArthurZucker Feb 5, 2024
26748c4
nit
ArthurZucker Feb 5, 2024
75e376a
nits
ArthurZucker Feb 14, 2024
1c104b5
Style
ArthurZucker Feb 14, 2024
0e90dae
Merge
ArthurZucker Feb 14, 2024
a804466
Small changes
ArthurZucker Feb 14, 2024
6b87ad2
Push dummy mambda simple slow
ArthurZucker Feb 14, 2024
a7ec8d6
nit
ArthurZucker Feb 14, 2024
5046451
Use original names
ArthurZucker Feb 14, 2024
b5831e3
Use original names and remove norm
ArthurZucker Feb 15, 2024
e9a80ad
Updates for inference params
ArthurZucker Feb 15, 2024
ee4a7ef
Style nd updates
ArthurZucker Feb 15, 2024
d8c195f
nits
ArthurZucker Feb 15, 2024
e64fedc
Match logits
ArthurZucker Feb 16, 2024
aee558f
Add a test
ArthurZucker Feb 16, 2024
eae5f45
Add expected generated text
ArthurZucker Feb 16, 2024
1f8e8d0
nits doc, imports and styling
ArthurZucker Feb 16, 2024
3cc06e5
style
ArthurZucker Feb 16, 2024
5a5324c
oups
ArthurZucker Feb 16, 2024
325b66b
Merge branch 'main' of github.com:huggingface/transformers into add-m…
ArthurZucker Feb 16, 2024
81303f4
dont install kernels, invite users to install the required kernels
ArthurZucker Feb 19, 2024
1a10310
let use use the original packages
ArthurZucker Feb 19, 2024
89fb490
styling
ArthurZucker Feb 19, 2024
6cfe216
nits
ArthurZucker Feb 19, 2024
1ecbd22
fix some copieds
ArthurZucker Feb 19, 2024
b937122
update doc
ArthurZucker Feb 19, 2024
9752dd0
fix-copies
ArthurZucker Feb 19, 2024
a7881a3
styling done
ArthurZucker Feb 19, 2024
f445b0d
nits
ArthurZucker Feb 19, 2024
64ec8dd
fix import check
ArthurZucker Feb 19, 2024
e6e3ba8
run but wrong cuda ress
ArthurZucker Feb 19, 2024
ed4eb4c
mamba CUDA works :)
ArthurZucker Feb 19, 2024
4c8fc48
fix the fast path
ArthurZucker Feb 19, 2024
69e103f
config naming nits
ArthurZucker Feb 19, 2024
ba21ff2
conversion script is not required at this stage
ArthurZucker Feb 19, 2024
fe53728
finish fixing the fast path: generation make sense now!
ArthurZucker Feb 19, 2024
9411169
nit
ArthurZucker Feb 19, 2024
c2c7709
Let's start working on the CIs
ArthurZucker Feb 19, 2024
1e73ca9
style
ArthurZucker Feb 19, 2024
834f46f
git push Merge branch 'main' of github.com:huggingface/transformers i…
ArthurZucker Feb 19, 2024
a1a94f3
Merge branch 'main' of github.com:huggingface/transformers into add-m…
ArthurZucker Feb 20, 2024
2213222
better style
ArthurZucker Feb 20, 2024
2a02006
more nits
ArthurZucker Feb 20, 2024
8b0412f
test nit
ArthurZucker Feb 20, 2024
fbd6a2c
quick fix for now
ArthurZucker Feb 20, 2024
823f11a
nits
ArthurZucker Feb 20, 2024
88896a9
nit
ArthurZucker Feb 20, 2024
7f72ee8
nit
ArthurZucker Feb 21, 2024
0555247
Merge branch 'main' of github.com:huggingface/transformers into add-m…
ArthurZucker Feb 29, 2024
0072a6c
nit
ArthurZucker Feb 29, 2024
7f6c56f
nits
ArthurZucker Feb 29, 2024
f67c353
update test rest
ArthurZucker Feb 29, 2024
2ab5a86
fixup
ArthurZucker Feb 29, 2024
8920be3
update test
ArthurZucker Feb 29, 2024
87d0664
nit
ArthurZucker Feb 29, 2024
8b00d76
some fixes
ArthurZucker Feb 29, 2024
ca9835c
nits
ArthurZucker Feb 29, 2024
796ef3e
update test values
ArthurZucker Feb 29, 2024
170664a
fix styling
ArthurZucker Feb 29, 2024
92493a0
nit
ArthurZucker Feb 29, 2024
854ebad
support peft
ArthurZucker Feb 29, 2024
3bbd1b1
Merge branch 'main' of github.com:huggingface/transformers into add-m…
ArthurZucker Feb 29, 2024
aa0e6bb
integrations tests require torchg
ArthurZucker Feb 29, 2024
3c1537e
also add slow markers
ArthurZucker Feb 29, 2024
d06421a
styling
ArthurZucker Feb 29, 2024
5fb8062
chose forward wisely
ArthurZucker Feb 29, 2024
edb4e91
nits
ArthurZucker Feb 29, 2024
eb1fb64
update tests
ArthurZucker Feb 29, 2024
de4fe46
fix gradient checkpointing
ArthurZucker Feb 29, 2024
54ffaa3
fixup
ArthurZucker Feb 29, 2024
977d34f
nit
ArthurZucker Feb 29, 2024
0928453
fix doc
ArthurZucker Feb 29, 2024
2c90536
check copies
ArthurZucker Feb 29, 2024
4ba9c79
fix the docstring
ArthurZucker Feb 29, 2024
3651dba
fix some more tests
ArthurZucker Feb 29, 2024
426e6f3
style
ArthurZucker Feb 29, 2024
951b1aa
fix beam search
ArthurZucker Mar 1, 2024
4101369
add init schene
ArthurZucker Mar 1, 2024
65db96b
update
ArthurZucker Mar 1, 2024
0f3dfc7
nit
ArthurZucker Mar 1, 2024
f8bd0aa
fix
ArthurZucker Mar 1, 2024
b2bd0c7
fixup the doc
ArthurZucker Mar 1, 2024
cf58529
fix the doc
ArthurZucker Mar 1, 2024
e9c3447
fixup
ArthurZucker Mar 1, 2024
1282a75
tentative update but slow is no longer good
ArthurZucker Mar 1, 2024
fa561b2
nit
ArthurZucker Mar 1, 2024
91b8106
should we always use float32?
ArthurZucker Mar 1, 2024
e8142ca
nits
ArthurZucker Mar 1, 2024
623b636
revert wrong changes
ArthurZucker Mar 1, 2024
566c799
res in float32
ArthurZucker Mar 1, 2024
5d637d9
cleanup
ArthurZucker Mar 2, 2024
648a292
skip fmt for now
ArthurZucker Mar 2, 2024
e306e89
update generation values
ArthurZucker Mar 2, 2024
057d7a3
update test values running original model
ArthurZucker Mar 2, 2024
72f8936
fixup
ArthurZucker Mar 2, 2024
f415081
update tests + rename inference_params to cache_params + make sure tr…
ArthurZucker Mar 4, 2024
6bb659a
small nits
ArthurZucker Mar 4, 2024
178fe76
more nits
ArthurZucker Mar 4, 2024
3a46724
fix final CIs
ArthurZucker Mar 4, 2024
13204e0
style
ArthurZucker Mar 4, 2024
1608a90
nit doc
ArthurZucker Mar 4, 2024
99119ba
I hope final doc nits
ArthurZucker Mar 4, 2024
d6fb1ef
nit
ArthurZucker Mar 4, 2024
844530f
🫠
ArthurZucker Mar 4, 2024
52be018
final touch!
ArthurZucker Mar 4, 2024
d03de1c
fix torch import
ArthurZucker Mar 4, 2024
c0672a8
Apply suggestions from code review
ArthurZucker Mar 5, 2024
dfc1212
Apply suggestions from code review
ArthurZucker Mar 5, 2024
acd4ccf
fix fix and fix
ArthurZucker Mar 5, 2024
2ddd9aa
fix base model prefix!
ArthurZucker Mar 5, 2024
0c5d7ed
nit
ArthurZucker Mar 5, 2024
28e5ef0
Update src/transformers/models/mamba/__init__.py
ArthurZucker Mar 5, 2024
f963e38
Update docs/source/en/model_doc/mamba.md
ArthurZucker Mar 5, 2024
095dabd
nit
ArthurZucker Mar 5, 2024
start cleaning
ArthurZucker committed Jan 31, 2024
commit 00d3a6c1f0f35f9a0e0a88530fd3c5fd1f3f2db0
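For orientation before the diff itself: below is a minimal usage sketch of the model classes this PR introduces. The checkpoint id is a placeholder assumption, and the exact argument and attribute names were still in flux at this commit (for example, inference_params was only renamed to cache_params in a later commit), so treat it as illustrative rather than as the final API.

```python
# Illustrative sketch only: assumes the eventual MambaForCausalLM class and a
# converted checkpoint id such as "state-spaces/mamba-130m-hf" (placeholder).
from transformers import AutoTokenizer, MambaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

inputs = tokenizer("Hey how are you doing?", return_tensors="pt")
# Generation carries a fixed-size recurrent state between steps instead of a
# transformer-style key/value cache.
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```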
213 changes: 42 additions & 171 deletions src/transformers/models/mamba/modeling_mamba.py
@@ -54,14 +54,14 @@
mamba_cuda_kernel = None


# Copied from transformers.models.rwkv.modeling_rwkv.load_wkv_cuda_kernel with RWKV->MAMBA,rwkv->mamba
def load_wkv_cuda_kernel(context_length):
# Copied from transformers.models.mamba.modeling_mamba.load_mamba_cuda_kernel with mamba->MAMBA,mamba->mamba
def load_mamba_cuda_kernel(context_length):
from torch.utils.cpp_extension import load as load_kernel

global mamba_cuda_kernel

kernel_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "mamba"
cuda_kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu", "wkv_cuda_bf16.cu"]]
cuda_kernel_files = [kernel_folder / f for f in ["mamba_op.cpp", "mamba_cuda.cu", "mamba_cuda_bf16.cu"]]

# Only load the kernel if it's not been loaded yet or if we changed the context length
if mamba_cuda_kernel is not None and mamba_cuda_kernel.max_seq_length == context_length:
@@ -79,16 +79,15 @@ def load_wkv_cuda_kernel(context_length):
f"-DTmax={context_length}",
]
mamba_cuda_kernel = load_kernel(
name=f"wkv_{context_length}",
name=f"mamba_{context_length}",
sources=cuda_kernel_files,
verbose=(logging.get_verbosity() == logging.DEBUG),
extra_cuda_cflags=flags,
)
mamba_cuda_kernel.max_seq_length = context_length


# Copied from transformers.models.rwkv.modeling_rwkv.RwkvLinearAttention with Rwkv->Mamba,rwkv->mamba
class MambaLinearAttention(torch.autograd.Function):
class MambaMixer(torch.autograd.Function):
@staticmethod
def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False):
batch_size, seq_len, hidden_size = key.size()
@@ -111,7 +110,7 @@ def forward(ctx, time_decay, time_first, key, value, state=None, return_state=Fa
or key.device.type != "cuda"
or value.device.type != "cuda"
):
raise ValueError("Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.")
raise ValueError("Calling the CUDA kernel for mamba attention requires all tensors to be on CUDA devices.")

time_decay = -torch.exp(time_decay.float().contiguous())
if key.dtype == torch.float16:
@@ -194,7 +193,6 @@ def backward(ctx, g_output, g_state=None):
)


# Copied from transformers.models.rwkv.modeling_rwkv.rwkv_linear_attention_cpu with rwkv->mamba
def mamba_linear_attention_cpu(time_decay, time_first, key, value, state=None, return_state=False):
# For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
# within a torch.no_grad.
@@ -217,7 +215,7 @@ def mamba_linear_attention_cpu(time_decay, time_first, key, value, state=None, r
current_key = key[:, current_index].float()
current_value = value[:, current_index]

# wkv computation at time t
# mamba computation at time t
max_for_output = torch.maximum(max_state, current_key + time_first)
e1 = torch.exp(max_state - max_for_output)
e2 = torch.exp(current_key + time_first - max_for_output)
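Aside: the max_for_output / e1 / e2 lines above are the log-sum-exp trick that keeps the per-timestep recurrence numerically stable on CPU. At this commit the function is still the WKV recurrence inherited from the RWKV port (see the removed "Copied from ... rwkv_linear_attention_cpu" marker). A self-contained sketch of one step follows; the accumulation and state update are filled in from the upstream RWKV reference rather than from this hunk, so take them as an assumption.

```python
import torch

def stable_wkv_step(num_state, den_state, max_state, time_decay, time_first, current_key, current_value):
    # Output at time t, computed with exponentials shifted by a running maximum
    # so they never overflow.
    max_for_output = torch.maximum(max_state, current_key + time_first)
    e1 = torch.exp(max_state - max_for_output)
    e2 = torch.exp(current_key + time_first - max_for_output)
    output = (e1 * num_state + e2 * current_value) / (e1 * den_state + e2)

    # Decay the running state and fold in the current token, again in shifted space.
    max_for_state = torch.maximum(max_state + time_decay, current_key)
    e1 = torch.exp(max_state + time_decay - max_for_state)
    e2 = torch.exp(current_key - max_for_state)
    num_state = e1 * num_state + e2 * current_value
    den_state = e1 * den_state + e2
    return output, (num_state, den_state, max_for_state)
```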
@@ -239,7 +237,6 @@ def mamba_linear_attention_cpu(time_decay, time_first, key, value, state=None, r
return output, state


# Copied from transformers.models.rwkv.modeling_rwkv.rwkv_linear_attention with Rwkv->Mamba,rwkv->mamba
def mamba_linear_attention(time_decay, time_first, key, value, state=None, return_state=False):
no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, key, value])
# Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
@@ -251,131 +248,14 @@ def mamba_linear_attention(time_decay, time_first, key, value, state=None, retur
return MambaLinearAttention.apply(time_decay, time_first, key, value, state, return_state)


# Copied from transformers.models.rwkv.modeling_rwkv.RwkvSelfAttention with RWKV->MAMBA,Rwkv->Mamba,rwkv->mamba
class MambaSelfAttention(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.config = config
kernel_loaded = mamba_cuda_kernel is not None and mamba_cuda_kernel.max_seq_length == config.context_length
if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded:
try:
load_wkv_cuda_kernel(config.context_length)
except Exception:
logger.info("Could not load the custom CUDA kernel for MAMBA attention.")
self.layer_id = layer_id
hidden_size = config.hidden_size
attention_hidden_size = (
config.attention_hidden_size if config.attention_hidden_size is not None else hidden_size
)
self.attention_hidden_size = attention_hidden_size

self.time_decay = nn.Parameter(torch.empty(attention_hidden_size))
self.time_first = nn.Parameter(torch.empty(attention_hidden_size))

self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
self.time_mix_value = nn.Parameter(torch.empty(1, 1, hidden_size))
self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))

self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.key = nn.Linear(hidden_size, attention_hidden_size, bias=False)
self.value = nn.Linear(hidden_size, attention_hidden_size, bias=False)
self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)

# TODO: maybe jit, otherwise move inside forward
def extract_key_value(self, hidden, state=None):
# Mix hidden with the previous timestep to produce key, value, receptance
if hidden.size(1) == 1 and state is not None:
shifted = state[1][:, :, self.layer_id]
else:
shifted = self.time_shift(hidden)
if state is not None:
shifted[:, 0] = state[1][:, :, self.layer_id]
key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)

key = self.key(key)
value = self.value(value)
receptance = torch.sigmoid(self.receptance(receptance))
if state is not None:
state[1][:, :, self.layer_id] = hidden[:, -1]
return receptance, key, value, state

def forward(self, hidden, state=None, use_cache=False):
receptance, key, value, state = self.extract_key_value(hidden, state=state)
layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None
mamba, layer_state = mamba_linear_attention(
self.time_decay,
self.time_first,
key,
value,
state=layer_state,
return_state=use_cache,
)

if layer_state is not None:
state[2][:, :, self.layer_id] = layer_state[0]
state[3][:, :, self.layer_id] = layer_state[1]
state[4][:, :, self.layer_id] = layer_state[2]

return self.output(receptance * mamba), state


# Copied from transformers.models.rwkv.modeling_rwkv.RwkvFeedForward with Rwkv->Mamba
class MambaFeedForward(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.config = config
self.layer_id = layer_id
hidden_size = config.hidden_size
intermediate_size = (
config.intermediate_size if config.intermediate_size is not None else 4 * config.hidden_size
)

self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))

self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
self.receptance = nn.Linear(hidden_size, hidden_size, bias=False)
self.value = nn.Linear(intermediate_size, hidden_size, bias=False)

def forward(self, hidden, state=None):
if hidden.size(1) == 1 and state is not None:
shifted = state[0][:, :, self.layer_id]
else:
shifted = self.time_shift(hidden)
if state is not None:
shifted[:, 0] = state[0][:, :, self.layer_id]
key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)

key = torch.square(torch.relu(self.key(key)))
value = self.value(key)
receptance = torch.sigmoid(self.receptance(receptance))

if state is not None:
state[0][:, :, self.layer_id] = hidden[:, -1]

return receptance * value, state


# Copied from transformers.models.rwkv.modeling_rwkv.RwkvBlock with Rwkv->Mamba
class MambaBlock(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.config = config
self.layer_id = layer_id

if layer_id == 0:
self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

self.attention = MambaSelfAttention(config, layer_id)
self.feed_forward = MambaFeedForward(config, layer_id)
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mixer = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

def forward(self, hidden, state=None, use_cache=False, output_attentions=False):
if self.layer_id == 0:
@@ -396,7 +276,7 @@ def forward(self, hidden, state=None, use_cache=False, output_attentions=False):
return outputs


# Copied from transformers.models.rwkv.modeling_rwkv.RwkvPreTrainedModel with Rwkv->Mamba,rwkv->mamba
# Copied from transformers.models.mamba.modeling_mamba.mambaPreTrainedModel with mamba->Mamba,mamba->mamba
class MambaPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -411,7 +291,7 @@ class MambaPreTrainedModel(PreTrainedModel):

def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, MambaSelfAttention):
if isinstance(module, MambaMixer):
layer_id = module.layer_id
num_hidden_layers = module.config.num_hidden_layers
hidden_size = module.config.hidden_size
@@ -448,27 +328,32 @@ def _init_weights(self, module):
module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
elif isinstance(module, MambaFeedForward):
layer_id = module.layer_id
num_hidden_layers = module.config.num_hidden_layers
hidden_size = module.config.hidden_size

ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0

time_weight = torch.tensor(
[i / hidden_size for i in range(hidden_size)],
dtype=module.time_mix_key.dtype,
device=module.time_mix_key.device,
)
time_weight = time_weight[None, None, :]

with torch.no_grad():
module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
if isinstance(module, nn.Linear):
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=self.config.initializer_range)

if self.config.rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
with torch.no_grad():
p /= math.sqrt(self.config.n_residuals_per_layer * self.config.n_layer)


@dataclass
# Copied from transformers.models.rwkv.modeling_rwkv.RwkvOutput with RWKV->MAMBA,Rwkv->Mamba
class MambaOutput(ModelOutput):
"""
Class for the MAMBA model outputs.
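Aside on the rescale_prenorm_residual branch added to _init_weights above: residual-path projections (out_proj.weight, fc2.weight) are re-initialized and then scaled down by 1/sqrt(n_residuals_per_layer * n_layer), the GPT-2 / Megatron-LM scheme cited in the comments. A standalone sketch of the effect, with made-up sizes and layer counts for illustration:

```python
import math

import torch
from torch import nn

n_layer = 24                # illustrative depth
n_residuals_per_layer = 1   # one residual contribution per block, as assumed here

# A residual-path projection; the dimensions are arbitrary for the example.
out_proj = nn.Linear(1536, 768, bias=False)
nn.init.kaiming_uniform_(out_proj.weight, a=math.sqrt(5))
with torch.no_grad():
    # Shrink the weights so the variance accumulated along the residual stream
    # stays bounded as depth grows.
    out_proj.weight /= math.sqrt(n_residuals_per_layer * n_layer)
```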
@@ -499,7 +384,6 @@ class MambaOutput(ModelOutput):


@dataclass
# Copied from transformers.models.rwkv.modeling_rwkv.RwkvCausalLMOutput with Rwkv->Mamba
class MambaCausalLMOutput(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
@@ -595,14 +479,13 @@ class MambaCausalLMOutput(ModelOutput):
"The bare MAMBA Model transformer outputting raw hidden-states without any specific head on top.",
MAMBA_START_DOCSTRING,
)
# Copied from transformers.models.rwkv.modeling_rwkv.RwkvModel with RWKV->MAMBA,Rwkv->Mamba
class MambaModel(MambaPreTrainedModel):
def __init__(self, config):
super().__init__(config)

self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.blocks = nn.ModuleList([MambaBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
self.ln_out = nn.LayerNorm(config.hidden_size)
self.layers = nn.ModuleList([MambaBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
self.norm_f = nn.LayerNorm(config.hidden_size)

self.layers_are_rescaled = False

@@ -654,13 +537,8 @@ def forward(

if use_cache and state is None:
shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers)
state = [
torch.zeros(
*shape, dtype=inputs_embeds.dtype if i <= 1 else torch.float32, device=inputs_embeds.device
)
for i in range(5)
]
state[4] -= 1e30
dtype = inputs_embeds.dtype if i <= 1 else torch.float32
state = [torch.zeros(*shape, dtype=dtype, device=inputs_embeds.device) for i in range(5)]

if self.gradient_checkpointing and self.training:
if use_cache:
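Aside on the use_cache branch above: rather than a key/value cache that grows with the generated sequence, the model allocates a fixed-size state per layer, shaped (batch_size, hidden_size, num_hidden_layers), with five such tensors tracked at this stage of the port. A small illustrative helper (not code from this PR) showing how such a state is allocated and sliced per layer:

```python
import torch

def init_recurrent_state(batch_size, hidden_size, num_hidden_layers, dtype=torch.float32, device="cpu"):
    # One fixed-size tensor per tracked quantity; each layer reads and writes its
    # own slice along the last dimension, so memory does not grow with sequence length.
    shape = (batch_size, hidden_size, num_hidden_layers)
    return [torch.zeros(*shape, dtype=dtype, device=device) for _ in range(5)]

state = init_recurrent_state(batch_size=2, hidden_size=768, num_hidden_layers=24)
layer_id = 0
layer_state = state[0][:, :, layer_id]  # the slice a given layer consumes at step t
```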
@@ -673,23 +551,16 @@

all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for idx, block in enumerate(self.blocks):
for idx, layer in enumerate(self.layers):
if self.gradient_checkpointing and self.training:
hidden_states, state, attentions = self._gradient_checkpointing_func(
block.__call__, hidden_states, state, use_cache, output_attentions
layer.__call__, hidden_states, state, use_cache, output_attentions
)
else:
hidden_states, state, attentions = block(
hidden_states, state, attentions = layer(
hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
)

if (
self.layers_are_rescaled
and self.config.rescale_every > 0
and (idx + 1) % self.config.rescale_every == 0
):
hidden_states = hidden_states / 2

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

@@ -764,7 +635,7 @@ def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
""",
MAMBA_START_DOCSTRING,
)
# Copied from transformers.models.rwkv.modeling_rwkv.RwkvForCausalLM with RWKV->MAMBA,Rwkv->Mamba,rwkv->mamba
# Copied from transformers.models.mamba.modeling_mamba.mambaForCausalLM with mamba->MAMBA,mamba->Mamba,mamba->mamba
class MambaForCausalLM(MambaPreTrainedModel):
_tied_weights_keys = ["head.weight"]
