Skip to content

Commit 2edb0e4

Browse files
authored
[mllama] fix loading and inference (#38223)
fix loading
1 parent 390f153 commit 2edb0e4

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

src/transformers/models/mllama/modeling_mllama.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -486,8 +486,6 @@ def forward(
486486
value_states = self.v_proj(cross_attention_states)
487487
key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
488488
value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
489-
key_states = repeat_kv(key_states, self.num_key_value_groups)
490-
value_states = repeat_kv(value_states, self.num_key_value_groups)
491489

492490
key_states = self.k_norm(key_states)
493491
if past_key_value is not None:
@@ -850,7 +848,7 @@ def forward(self, x, position_ids):
850848
@auto_docstring
851849
class MllamaPreTrainedModel(PreTrainedModel):
852850
config_class = MllamaConfig
853-
base_model_prefix = "model"
851+
base_model_prefix = ""
854852
supports_gradient_checkpointing = True
855853
_no_split_modules = [
856854
"MllamaVisionEncoderLayer",

0 commit comments

Comments (0)