Merge pull request #24 from Blaizzy/pc/quantise-irregular
Add support for PaliGemma and Quant Siglip
Blaizzy authored May 24, 2024
2 parents 9a2a72e + c70c916 commit 3c47b80
Showing 19 changed files with 834 additions and 36 deletions.
26 changes: 16 additions & 10 deletions mlx_vlm/chat_ui.py
@@ -54,7 +54,10 @@ def generate(
     else:
         tokenizer = processor.tokenizer

-    input_ids, pixel_values = prepare_inputs(image_processor, processor, image, prompt)
+    image_token_index = model.config.image_token_index
+    input_ids, pixel_values = prepare_inputs(
+        image_processor, processor, image, prompt, image_token_index
+    )
     logits, cache = model(input_ids, pixel_values)
     logits = logits[:, -1, :]
     y, _ = sample(logits, temp, top_p)
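
Note on the hunk above: prepare_inputs now receives the model's image token id instead of assuming one, since each architecture marks the image position with a different special token. A minimal sketch of what the extra argument enables; the helper below is hypothetical, not part of this commit:

import mlx.core as mx

# Hypothetical helper: find where the image placeholder sits in the prompt,
# given the per-model id (e.g. 32000 for LLaVA, per the ModelConfig hunk below).
def image_token_positions(input_ids: mx.array, image_token_index: int) -> list:
    return [i for i, tok in enumerate(input_ids[0].tolist()) if tok == image_token_index]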
@@ -89,8 +92,7 @@ def generate(
 def chat(message, history, temperature, max_tokens):

     chat = []
-
-    if message["files"]:
+    if len(message["files"]) >= 0:
         chat.append(get_message_json(config["model_type"], message["text"]))
     else:
         raise Exception("Please upload an image. Text only chat is not supported.")
@@ -103,16 +105,20 @@ def chat(message, history, temperature, max_tokens):
         )

     elif "tokenizer" in processor.__dict__.keys():
-        messages = processor.tokenizer.apply_chat_template(
-            chat,
-            tokenize=False,
-            add_generation_prompt=True,
-        )
+        if processor.tokenizer.chat_template:
+            messages = processor.tokenizer.apply_chat_template(
+                chat,
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+        else:
+            messages = message["text"]

     response = ""
     for chunk in generate(
         model,
         processor,
-        message["files"][0],
+        message["files"][-1],
         messages,
         image_processor,
         temperature,
@@ -130,7 +136,7 @@ def chat(message, history, temperature, max_tokens):
     ),
     additional_inputs=[
         gr.Slider(
-            minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
+            minimum=0, maximum=1, step=0.1, value=0.1, label="Temperature", render=False
         ),
         gr.Slider(
             minimum=128,
11 changes: 6 additions & 5 deletions mlx_vlm/generate.py
@@ -78,11 +78,12 @@ def main():
         )

     elif "tokenizer" in processor.__dict__.keys():
-        prompt = processor.tokenizer.apply_chat_template(
-            [get_message_json(config["model_type"], prompt)],
-            tokenize=False,
-            add_generation_prompt=True,
-        )
+        if model.config.model_type != "paligemma":
+            prompt = processor.tokenizer.apply_chat_template(
+                [get_message_json(config["model_type"], prompt)],
+                tokenize=False,
+                add_generation_prompt=True,
+            )

     else:
         ValueError(
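Note on the hunk above: PaliGemma is not a chat model, and its processor is expected to build the final input from the raw text itself, so the chat-template step is bypassed for it. A sketch of the branching under that assumption:

# Sketch (an assumption about intent, not code from this commit).
def format_prompt(model_type: str, prompt: str, tokenizer) -> str:
    if model_type != "paligemma" and tokenizer.chat_template:
        # Chat models get the templated string.
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
    # PaliGemma: pass the text through; the processor wraps it itself.
    return prompt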
4 changes: 3 additions & 1 deletion mlx_vlm/models/idefics2/idefics2.py
@@ -244,7 +244,9 @@ def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_id
         return inputs_embeds

-    def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
+    def __call__(
+        self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None
+    ):
         input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
         logits, cache = self.language_model(
             inputs=input_ids, cache=cache, inputs_embeds=input_embeddings
1 change: 1 addition & 0 deletions mlx_vlm/models/idefics2/language.py
@@ -147,6 +147,7 @@ def __call__(
         inputs: mx.array,
         cache=None,
         inputs_embeds=None,
+        mask: Optional[mx.array] = None,
     ):
         # for passing merged input embeddings
         if inputs_embeds is None:
1 change: 1 addition & 0 deletions mlx_vlm/models/llava/language.py
@@ -198,6 +198,7 @@ def __call__(
         inputs: mx.array,
         cache=None,
         inputs_embeds=None,
+        mask: Optional[mx.array] = None,
     ):
         out, cache = self.model(inputs, cache, inputs_embeds)
         return self.lm_head(out), cache
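Note: the mask: Optional[mx.array] = None parameter added to the idefics2, llava, and nanoLlava language models gives every model the same call signature, so generation code can pass (or omit) an attention mask without per-model branching; models that build their own causal mask can simply ignore it. A caller-side sketch under that reading:

from typing import Optional

import mlx.core as mx

# Sketch: one generation step written against the uniform signature,
# whether or not the particular model consumes the mask.
def step(language_model, tokens: mx.array, cache=None, embeds=None,
         mask: Optional[mx.array] = None):
    return language_model(tokens, cache=cache, inputs_embeds=embeds, mask=mask)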
5 changes: 4 additions & 1 deletion mlx_vlm/models/llava/llava.py
@@ -18,6 +18,7 @@
 class ModelConfig:
     text_config: TextConfig
     vision_config: VisionConfig
+    model_type: str
     ignore_index: int = -100
     image_token_index: int = 32000
     vision_feature_select_strategy: str = "default"
@@ -130,7 +131,9 @@ def _merge_input_ids_with_image_features(
         # (1, num_image_patches*num_images + sequence_len, embed_dim)
         return mx.concatenate(final_embeddings, axis=1)

-    def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
+    def __call__(
+        self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None
+    ):
         input_embddings = self.get_input_embeddings(input_ids, pixel_values)
         logits, cache = self.language_model(
             input_ids, cache=cache, inputs_embeds=input_embddings
8 changes: 5 additions & 3 deletions mlx_vlm/models/llava/vision.py
@@ -5,6 +5,7 @@

 import mlx.core as mx
 import mlx.nn as nn
+import numpy as np


 @dataclass
@@ -172,8 +173,10 @@ def __call__(self, x: mx.array) -> mx.array:
         cls_embeddings = mx.broadcast_to(
             self.class_embedding, (batch_size, 1, embed_dim)
         )
+        position_ids = mx.array(np.arange(self.num_positions)[None, :])

         embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
-        embeddings += self.position_embedding.weight
+        embeddings += self.position_embedding(position_ids)
         return embeddings

@@ -219,8 +222,7 @@ def __call__(
     ) -> mx.array:
         return self.vision_model(x, output_hidden_states)

-    @staticmethod
-    def sanitize(weights):
+    def sanitize(self, weights):
         sanitized_weights = {}
         for k, v in weights.items():
             if "position_ids" in k:
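Note: the position-embedding edits in vision.py here (and in nanoLlava below) are presumably the "Quant Siglip" half of this PR. Adding self.position_embedding.weight only works while the layer is a plain float nn.Embedding; once the model is quantized the layer stores packed integer weights, and only calling the module dequantizes them into usable embeddings. A small sketch of the lookup pattern, with made-up dimensions:

import mlx.core as mx
import mlx.nn as nn
import numpy as np

# Dimensions are illustrative. Indexing through the layer works for both
# nn.Embedding and its quantized counterpart; adding .weight directly does
# not once the table is quantized.
num_positions, dim = 577, 1024
position_embedding = nn.Embedding(num_positions, dim)
position_ids = mx.array(np.arange(num_positions)[None, :])
embeddings = position_embedding(position_ids)  # shape: (1, 577, 1024)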
1 change: 1 addition & 0 deletions mlx_vlm/models/nanoLlava/language.py
@@ -193,6 +193,7 @@ def __call__(
         inputs: mx.array,
         cache=None,
         inputs_embeds=None,
+        mask: Optional[mx.array] = None,
     ):
         out, cache = self.model(inputs, cache, inputs_embeds=inputs_embeds)
         return out, cache
8 changes: 7 additions & 1 deletion mlx_vlm/models/nanoLlava/nanoLlava.py
@@ -177,7 +177,13 @@ def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_id
         # (1, num_image_patches*num_images + sequence_len, embed_dim)
         return mx.concatenate(final_embeddings, axis=1)

-    def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
+    def __call__(
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: Optional[mx.array] = None,
+        cache=None,
+    ):
         input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
         logits, cache = self.language_model(
             inputs=input_ids, cache=cache, inputs_embeds=input_embeddings
5 changes: 3 additions & 2 deletions mlx_vlm/models/nanoLlava/vision.py
@@ -4,6 +4,7 @@

 import mlx.core as mx
 import mlx.nn as nn
+import numpy as np


 @dataclass
@@ -206,9 +207,9 @@ def __call__(self, x: mx.array) -> mx.array:
         batch_size = x.shape[0]
         patch_embeddings = self.patch_embedding(x)
         patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
-
+        self.position_ids = mx.array(np.arange(self.num_positions)[None, :])
         embeddings = patch_embeddings
-        embeddings += self.position_embedding.weight
+        embeddings += self.position_embedding(self.position_ids)
         return embeddings

8 changes: 8 additions & 0 deletions mlx_vlm/models/paligemma/__init__.py
@@ -0,0 +1,8 @@
+from .paligemma import (
+    LanguageModel,
+    Model,
+    ModelConfig,
+    TextConfig,
+    VisionConfig,
+    VisionModel,
+)
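
With the package exporting these names, PaliGemma should load through the usual mlx_vlm entry points. A usage sketch, assuming the library's load/generate helpers and an illustrative Hub repo id (neither is specified by this commit):

from mlx_vlm import load, generate

# Repo id and image path are assumptions for illustration.
model, processor = load("mlx-community/paligemma-3b-mix-448-8bit")
output = generate(model, processor, "cats.jpg", "describe this image", verbose=True)
print(output)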