fix gemma embedding scaling
Blaizzy committed May 24, 2024
1 parent c25d9c0 commit c70c916
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion mlx_vlm/models/paligemma/language.py
@@ -152,10 +152,12 @@ def __call__(
         # for passing merged input embeddings
         if inputs_embeds is None:
             h = self.embed_tokens(inputs)
-            h = h * (self.args.hidden_size**0.5)
+
         else:
             h = inputs_embeds
 
+        h = h * (self.args.hidden_size**0.5)
+
         if cache is not None:
             mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
             mask = mask.astype(h.dtype)
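For context, a minimal sketch of what the corrected flow does (the branch structure, `embed_tokens`, and `args.hidden_size` come from the diff above; the standalone helper name, the model width, and the example shapes are illustrative assumptions, not part of the commit):

    import math
    import mlx.core as mx

    def scale_input_embeddings(h: mx.array, hidden_size: int) -> mx.array:
        # Gemma multiplies input embeddings by sqrt(hidden_size) before the
        # first decoder layer. Moving the multiply below the
        # embed/inputs_embeds branch, as the diff does, applies it to merged
        # image+text embeddings passed in via inputs_embeds as well, not
        # only to freshly embedded text tokens.
        return h * math.sqrt(hidden_size)

    # Illustrative usage: both input paths now receive the same scaling.
    hidden_size = 2048                              # hypothetical model width
    merged = mx.random.normal((1, 8, hidden_size))  # stand-in for merged embeddings
    h = scale_input_embeddings(merged, hidden_size)

Scaling after the branch keeps the two input paths numerically consistent, which appears to be the bug the commit title refers to: before this change, only the text-token path was scaled.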
