remove the LoRA hack for the mamba dt_proj bias. It was solved in hug…
tomeras91 committed Mar 31, 2024
1 parent ce8b476 commit 810dfbf
Showing 1 changed file with 3 additions and 11 deletions.
14 changes: 3 additions & 11 deletions src/transformers/models/jamba/modeling_jamba.py
@@ -940,18 +940,10 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: MambaC
         # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
         # linear layers, and requires to call the forward pass directly.
         # The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
-        if hasattr(self.dt_proj, "base_layer"):
-            # In case of LoRA, we need to access the base layer to get the bias
-            time_proj_bias = self.dt_proj.base_layer.bias
-            self.dt_proj.base_layer.bias = None
-        else:
-            time_proj_bias = self.dt_proj.bias
-            self.dt_proj.bias = None
+        time_proj_bias = self.dt_proj.bias
+        self.dt_proj.bias = None
         discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
-        if hasattr(self.dt_proj, "base_layer"):
-            self.dt_proj.base_layer.bias = time_proj_bias
-        else:
-            self.dt_proj.bias = time_proj_bias
+        self.dt_proj.bias = time_proj_bias
 
         A = -torch.exp(self.A_log.float())
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
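Note: the pattern this commit keeps is to temporarily set the projection's bias to None, call the module's forward (so quantized replacements of `torch.nn.Linear` still work), and then restore the bias so it can be applied later, e.g. inside the fused kernel. Below is a minimal sketch of that pattern using a plain `nn.Linear` as a stand-in for `self.dt_proj`; the layer sizes and the `time_step` tensor are illustrative only, not the actual Jamba shapes.

```python
# Minimal sketch of the bias-swap pattern retained by this commit.
# A plain nn.Linear stands in for self.dt_proj; shapes are toy values.
import torch
import torch.nn as nn

dt_proj = nn.Linear(16, 32, bias=True)      # stand-in for self.dt_proj
time_step = torch.randn(2, 64, 16)          # (batch, seq_len, time_step_rank)

# Detach the bias, run the forward pass without it, then restore it.
# Going through forward() (instead of using .weight directly) keeps
# quantized Linear replacements working.
time_proj_bias = dt_proj.bias
dt_proj.bias = None
discrete_time_step = dt_proj(time_step).transpose(1, 2)   # (batch, 32, seq_len)
dt_proj.bias = time_proj_bias

print(discrete_time_step.shape)             # torch.Size([2, 32, 64])
```

The deleted branch additionally reached through `dt_proj.base_layer` to fetch and restore the bias when the layer was wrapped by LoRA; per the (truncated) commit title, that workaround was fixed upstream, so only the plain path above remains.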
