Add support for Pixtral-12B (#67)
* add pixtral (working example)

* convert to mlx

* add prompt utils

* add kwargs to all models and formatting

* formatting

* fix phi3v processor

* refactor cache and mask

* fix pixel val loading

* formatting
Blaizzy authored Sep 29, 2024
1 parent 50961f6 commit 1926065
Showing 19 changed files with 842 additions and 58 deletions.
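The recurring change across the language-model files below replaces the inline causal-mask construction and the Optional[Tuple[mx.array, mx.array]] cache annotation with shared create_attention_mask and KVCache helpers imported from ..base. Those helpers are not part of this diff; judging from the inline code being removed, create_attention_mask is presumably close to the following sketch (an assumption, not the actual mlx_vlm/models/base.py source):

import mlx.core as mx
import mlx.nn as nn

def create_attention_mask(h: mx.array):
    # Additive causal mask during prefill (sequence length > 1);
    # None for single-token decode, mirroring the inline code removed below.
    if h.shape[1] > 1:
        mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
        return mask.astype(h.dtype)
    return None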
7 changes: 6 additions & 1 deletion mlx_vlm/models/idefics2/idefics2.py
@@ -251,7 +251,12 @@ def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_id
return mx.concatenate(final_embeddings, axis=1)

def __call__(
self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None
self,
input_ids: mx.array,
pixel_values: mx.array,
mask: mx.array,
cache=None,
**kwargs,
):
input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
logits = self.language_model(
11 changes: 5 additions & 6 deletions mlx_vlm/models/idefics2/language.py
@@ -6,6 +6,8 @@
import mlx.core as mx
import mlx.nn as nn

from ..base import KVCache, create_attention_mask


@dataclass
class TextConfig:
@@ -62,7 +64,7 @@ def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
B, L, D = x.shape

@@ -116,7 +118,7 @@ def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -153,10 +155,7 @@ def __call__(
else:
h = inputs_embeds

mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h)

if cache is None:
cache = [None] * len(self.layers)
11 changes: 5 additions & 6 deletions mlx_vlm/models/llava/language.py
@@ -5,6 +5,8 @@
import mlx.core as mx
import mlx.nn as nn

from ..base import KVCache, create_attention_mask


@dataclass
class TextConfig:
@@ -78,7 +80,7 @@ def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
B, L, D = x.shape

@@ -132,7 +134,7 @@ def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -166,10 +168,7 @@ def __call__(
else:
h = inputs_embeds

mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h)

if cache is None:
cache = [None] * len(self.layers)
7 changes: 6 additions & 1 deletion mlx_vlm/models/llava/llava.py
@@ -132,7 +132,12 @@ def _merge_input_ids_with_image_features(
return mx.concatenate(final_embeddings, axis=1)

def __call__(
self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None
self,
input_ids: mx.array,
pixel_values: mx.array,
mask: mx.array,
cache=None,
**kwargs,
):
input_embddings = self.get_input_embeddings(input_ids, pixel_values)
logits = self.language_model(
7 changes: 2 additions & 5 deletions mlx_vlm/models/llava_bunny/language.py
@@ -5,7 +5,7 @@
import mlx.core as mx
import mlx.nn as nn

from ..base import KVCache
from ..base import KVCache, create_attention_mask


@dataclass
@@ -174,10 +174,7 @@ def __call__(
else:
h = inputs_embeds

mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h)

if cache is None:
cache = [None] * len(self.layers)
1 change: 1 addition & 0 deletions mlx_vlm/models/llava_bunny/llava_bunny.py
@@ -183,6 +183,7 @@ def __call__(
pixel_values: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
**kwargs,
):
input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
logits = self.language_model(
11 changes: 5 additions & 6 deletions mlx_vlm/models/llava_next/language.py
@@ -5,6 +5,8 @@
import mlx.core as mx
import mlx.nn as nn

from ..base import KVCache, create_attention_mask


@dataclass
class TextConfig:
@@ -78,7 +80,7 @@ def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
B, L, D = x.shape

@@ -132,7 +134,7 @@ def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -166,10 +168,7 @@ def __call__(
else:
h = inputs_embeds

mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h)

if cache is None:
cache = [None] * len(self.layers)
7 changes: 6 additions & 1 deletion mlx_vlm/models/llava_next/llava_next.py
@@ -136,7 +136,12 @@ def _merge_input_ids_with_image_features(
return mx.concatenate(final_embeddings, axis=1)

def __call__(
self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None
self,
input_ids: mx.array,
pixel_values: mx.array,
mask: mx.array,
cache=None,
**kwargs,
):

input_embddings = self.get_input_embeddings(input_ids, pixel_values)
11 changes: 5 additions & 6 deletions mlx_vlm/models/multi_modality/language.py
@@ -5,6 +5,8 @@
import mlx.core as mx
import mlx.nn as nn

from ..base import KVCache, create_attention_mask


@dataclass
class TextConfig:
@@ -78,7 +80,7 @@ def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
B, L, D = x.shape

@@ -132,7 +134,7 @@ def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -166,10 +168,7 @@ def __call__(
else:
h = inputs_embeds

mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h)

if cache is None:
cache = [None] * len(self.layers)
7 changes: 6 additions & 1 deletion mlx_vlm/models/multi_modality/multi_modality.py
@@ -360,7 +360,12 @@ def _merge_input_ids_with_image_features(
return mx.concatenate(final_embeddings, axis=1)

def __call__(
self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None
self,
input_ids: mx.array,
pixel_values: mx.array,
mask: mx.array,
cache=None,
**kwargs,
):

input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
12 changes: 6 additions & 6 deletions mlx_vlm/models/paligemma/language.py
@@ -5,6 +5,8 @@
import mlx.core as mx
import mlx.nn as nn

from ..base import KVCache, create_attention_mask


@dataclass
class TextConfig:
@@ -66,7 +68,7 @@ def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
B, L, D = x.shape

@@ -120,7 +122,7 @@ def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -155,11 +157,9 @@ def __call__(
else:
h = inputs_embeds

h = h * (self.config.hidden_size**0.5)
h *= self.config.hidden_size**0.5

if cache is not None:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)
mask = create_attention_mask(h)

if cache is None:
cache = [None] * len(self.layers)
16 changes: 10 additions & 6 deletions mlx_vlm/models/phi3_v/phi3_v.py
@@ -8,6 +8,7 @@
import mlx.nn as nn
import numpy as np

from ..base import KVCache, create_attention_mask
from .language import LanguageModel, TextConfig
from .su_rope import Phi3SuScaledRotaryEmbedding
from .vision import VisionConfig, VisionModel
@@ -90,7 +91,7 @@ def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
B, L, D = x.shape

@@ -148,7 +149,7 @@ def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -179,16 +180,18 @@ def __call__(
):
h = self.embed_tokens(inputs)
p = np.argwhere(inputs < 0).tolist()

if pixel_values is not None:
h = self.vision_embed_tokens(pixel_values, h, image_sizes, p)
mask = None
if h.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
mask = mask.astype(h.dtype)

mask = create_attention_mask(h)

if cache is None:
cache = [None] * len(self.layers)

for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)

return self.norm(h)


@@ -206,6 +209,7 @@ def __call__(
pixel_values=None,
mask=None,
cache=None,
**kwargs,
):
out = self.model(inputs, pixel_values, mask, cache)
return self.lm_head(out).astype(self.lm_head.weight.dtype)
8 changes: 8 additions & 0 deletions mlx_vlm/models/pixtral/__init__.py
@@ -0,0 +1,8 @@
from .pixtral import (
LanguageModel,
Model,
ModelConfig,
TextConfig,
VisionConfig,
VisionModel,
)
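With this module in place, the Pixtral classes are importable like the other model families. A hypothetical usage sketch, assuming the usual mlx_vlm.load entry point and an already converted checkpoint (the repository id below is an assumption, not taken from this commit):

from mlx_vlm import load
from mlx_vlm.models.pixtral import Model, ModelConfig  # names exported by the new __init__.py above

# Hypothetical converted checkpoint id; substitute whatever repo the convert step produced.
model, processor = load("mlx-community/pixtral-12b-4bit")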