Skip to content

Commit

Permalink
Fix Bark saving (#33266)
Browse files Browse the repository at this point in the history
  • Loading branch information
ylacombe authored Sep 3, 2024
1 parent 7ed9789 commit 979f477
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/transformers/models/bark/modeling_bark.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,17 @@ def resize_token_embeddings(

return model_embeds

def _tie_weights(self):
if getattr(self.config, "tie_word_embeddings", True):
self._tied_weights_keys = []
output_embeddings = self.get_output_embeddings()
input_embeddings = self.get_input_embeddings()

for i in range(self.config.n_codes_total - self.config.n_codes_given):
# self.input_embeds_layers[i + 1].weight = self.lm_heads[i].weight
self._tie_or_clone_weights(output_embeddings[i], input_embeddings[i + 1])
self._tied_weights_keys.append(f"lm_heads.{i}.weight")

def tie_weights(self):
"""
Tie the weights between the input embeddings list and the output embeddings list.
Expand Down

0 comments on commit 979f477

Please sign in to comment.