
Commit f6e4dc7

working model
Blaizzy committed May 2, 2024
1 parent 07228b6 commit f6e4dc7
Showing 3 changed files with 52 additions and 80 deletions.
27 changes: 11 additions & 16 deletions mlx_vlm/models/idefics2/idefics2.py
@@ -101,7 +101,7 @@ def __call__(
         values = mx.concatenate([value_cache, values], axis=2)

         output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+            queries, keys, values, scale=self.scale
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output), (keys, values)
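For context, a minimal sketch of the mx.fast.scaled_dot_product_attention call above, with illustrative idefics2-like shapes (64 learned latents attending over image-patch keys). With the mask argument dropped, the default mask=None means unmasked attention over all key positions:

# Sketch only; shapes are illustrative, not taken from this repo.
import mlx.core as mx

B, n_heads, n_latents, n_patches, head_dim = 1, 12, 64, 576, 64
queries = mx.random.normal((B, n_heads, n_latents, head_dim))  # latents
keys = mx.random.normal((B, n_heads, n_patches, head_dim))     # image patches
values = mx.random.normal((B, n_heads, n_patches, head_dim))

out = mx.fast.scaled_dot_product_attention(
    queries, keys, values, scale=head_dim**-0.5  # mask omitted, so unmasked
)
print(out.shape)  # (1, 12, 64, 64)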
@@ -128,20 +128,19 @@ def __call__(
         x: mx.array,
         hidden_states: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> mx.array:
         latents = self.input_latents_norm(x)
         context = self.input_context_norm(hidden_states)

-        latents, cache = self.self_attn(latents, context, mask=mask, cache=cache)
+        latents, _ = self.self_attn(latents, context, mask=mask)

         latents = x + latents
         r = latents

         latents = self.post_attention_layernorm(latents)
         latents = self.mlp(latents)
         latents = r + latents
-        return latents, cache
+        return latents


 class Idefics2PerceiverResampler(nn.Module):
@@ -157,18 +156,15 @@ def __init__(self, config: ModelConfig):
         ]
         self.norm = nn.RMSNorm(self.hidden_size, eps=config.text_config.rms_norm_eps)

-    def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache=None):
-
-        if cache is None:
-            cache = [None] * len(self.layers)
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None):

         h = mx.expand_dims(self.latents, axis=0)
         h = mx.repeat(h, x.shape[0], axis=0)

-        for e, layer in enumerate(self.layers):
-            h, cache[e] = layer(h, x, mask=mask, cache=cache[e])
+        for layer in self.layers:
+            h = layer(h, x, mask=mask)

-        return self.norm(h), cache
+        return self.norm(h)


 class MLP(nn.Module):
@@ -218,13 +214,14 @@ def get_input_embeddings(
             return self.language_model(input_ids)

         inputs_embeds = self.language_model.embed_tokens(input_ids)
-        *_, hidden_state = self.vision_model(
+        print(pixel_values[0].shape)
+        pooler_output, embeddings, hidden_state = self.vision_model(
             pixel_values[0].transpose(0, 2, 3, 1), output_hidden_states=True
         )

         image_features = hidden_state[-1].astype(pixel_values.dtype)

-        image_features, _ = self.connector(image_features, mask=None)
+        image_features = self.connector(image_features, mask=None)

         final_inputs_embeds = self._prepare_inputs_for_multimodal(
             image_features, inputs_embeds, input_ids
@@ -243,9 +240,7 @@ def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids):
         image_token_positions = mx.array(np.where(special_image_token_mask)[1])

         # Advanced indexing to place reshaped image features at the corresponding positions
-        inputs_embeds[0, image_token_positions, :] = reshaped_image_hidden_states[
-            : len(image_token_positions)
-        ]
+        inputs_embeds[0, image_token_positions, :] = reshaped_image_hidden_states

         return inputs_embeds
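A toy sketch of the advanced-indexing step above, using hypothetical sizes; each row of the image features lands at the position of an image token in the sequence:

# Toy sketch of the scatter in _prepare_inputs_for_multimodal; sizes are made up.
import mlx.core as mx

seq_len, hidden = 6, 4
inputs_embeds = mx.zeros((1, seq_len, hidden))
image_token_positions = mx.array([1, 2])  # where the <image> tokens sit
image_features = mx.ones((2, hidden))     # one feature row per image token

# Rows of image_features are written at the image-token positions.
inputs_embeds[0, image_token_positions, :] = image_features
print(inputs_embeds[0, :, 0])  # rows 1 and 2 are now ones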

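For reference, a minimal usage sketch of the simplified connector path after this change; the config object and the feature shape are illustrative assumptions, not values from this repo:

# Sketch: the perceiver resampler is now a pure function of the patch features.
import mlx.core as mx

resampler = Idefics2PerceiverResampler(config)     # assumes a populated ModelConfig
patch_features = mx.random.normal((1, 576, 1152))  # (batch, patches, hidden), illustrative

# Previously returned (latents, cache); now it returns just the resampled latents.
latents = resampler(patch_features, mask=None)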
2 changes: 1 addition & 1 deletion mlx_vlm/models/idefics2/language.py
@@ -18,7 +18,7 @@ class TextConfig:
     vocab_size: int
     num_key_value_heads: int
     rope_theta: float = 10000.0
-    rope_traditional: bool = True
+    rope_traditional: bool = False
     tie_word_embeddings: bool = False

     @classmethod
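For reference, a short sketch of what this flag controls in MLX's nn.RoPE: traditional=True rotates consecutive dimension pairs, while the traditional=False set here uses the split-halves rotation that typically matches Hugging Face-style checkpoints:

# Sketch: how rope_traditional typically reaches the RoPE module in an MLX model.
import mlx.core as mx
import mlx.nn as nn

head_dim = 64
rope = nn.RoPE(head_dim, traditional=False, base=10000.0)  # mirrors the config defaults

x = mx.random.normal((1, 8, 16, head_dim))  # (batch, heads, seq, head_dim)
y = rope(x)
print(y.shape)  # (1, 8, 16, 64)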
103 changes: 40 additions & 63 deletions mlx_vlm/models/idefics2/vision.py
@@ -4,6 +4,7 @@

 import mlx.core as mx
 import mlx.nn as nn
+import numpy as np


 @dataclass
@@ -79,59 +80,17 @@ def __init__(
         self.v_proj = nn.Linear(value_input_dims, value_dims, bias=True)
         self.out_proj = nn.Linear(value_dims, value_output_dims, bias=True)

-    def __call__(self, queries, keys, values, mask=None):
-        queries = self.q_proj(queries)
-        keys = self.k_proj(keys)
-        values = self.v_proj(values)
+    def __call__(self, x: mx.array, mask=None):
+        B, L, _ = x.shape
+        queries = self.q_proj(x)
+        keys = self.k_proj(x)
+        values = self.v_proj(x)

         num_heads = self.num_heads
-        B, L, D = queries.shape
-        _, S, _ = keys.shape
         queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
-        )
-        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.out_proj(output)
-
-
-class MHA(nn.Module):
-    def __init__(
-        self,
-        dims: int,
-        num_heads: int,
-        bias: bool = False,
-    ):
-        super().__init__()
-
-        if (dims % num_heads) != 0:
-            raise ValueError(
-                "The input feature dimensions should be divisible by the "
-                f"number of heads ({dims} % {num_heads}) != 0"
-            )
-
-        self.num_heads = num_heads
-        head_dim = dims // num_heads
-        self.scale = head_dim**-0.5
-
-        self.in_proj = nn.Linear(dims, dims * 3, bias=bias)
-        self.out_proj = nn.Linear(dims, dims, bias=bias)
-
-    def __call__(self, queries: mx.array, kv: mx.array, mask=None, cache=None):
-        B, L, D = queries.shape
-
-        qkv = self.in_proj(queries)
-        _, keys, values = mx.split(qkv, 3, axis=-1)
-
-        num_heads = self.num_heads
-        B, L, D = queries.shape
-        _, S, _ = keys.shape
-        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)

         output = mx.fast.scaled_dot_product_attention(
             queries, keys, values, scale=self.scale, mask=mask
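A minimal sketch of the head-splitting pattern the merged Attention now applies for pure self-attention, where queries, keys, and values all share the source length L (sizes are illustrative):

# Sketch only: splitting fused projections into attention heads.
import mlx.core as mx

B, L, n_heads, head_dim = 1, 576, 16, 72
x = mx.random.normal((B, L, n_heads * head_dim))

q = x.reshape(B, L, n_heads, -1).transpose(0, 2, 1, 3)  # (B, heads, L, head_dim)
k = x.reshape(B, L, n_heads, -1).transpose(0, 2, 1, 3)  # same length L as queries,
v = x.reshape(B, L, n_heads, -1).transpose(0, 2, 1, 3)  # so no separate S is needed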
@@ -164,7 +123,7 @@ def __init__(self, config: VisionConfig):

     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
         y = self.layer_norm1(x)
-        y = self.self_attn(y, y, y, mask)
+        y = self.self_attn(y, mask)
         x = x + y
         y = self.layer_norm2(x)
         y = self.mlp(y)
@@ -190,32 +149,31 @@ def __init__(self, config: VisionConfig):
             out_channels=self.embed_dim,
             kernel_size=self.patch_size,
             stride=self.patch_size,
-            dilation=1,
         )

         self.num_patches = (self.image_size // self.patch_size) ** 2
         self.num_positions = self.num_patches
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

-    def __call__(self, x: mx.array) -> mx.array:
-        batch_size, max_im_h, max_im_w, _ = x.shape
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+        B, H, W, C = x.shape
         patch_embeddings = self.patch_embedding(x)
         patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
         max_nb_patches_h, max_nb_patches_w = (
-            max_im_h // self.patch_size,
-            max_im_w // self.patch_size,
+            H // self.patch_size,
+            W // self.patch_size,
         )
-        position_ids = mx.zeros(
-            (batch_size, max_nb_patches_h * max_nb_patches_w)
-        ).astype(mx.uint64)
+        position_ids = np.full((B, max_nb_patches_h * max_nb_patches_w), fill_value=0)
+
         embeddings = patch_embeddings
-        embeddings += self.position_embedding(position_ids)
+        embeddings += self.position_embedding(mx.array(position_ids))
         return embeddings


 class VisionModel(nn.Module):
     def __init__(self, config: VisionConfig):
         super().__init__()
+        self.config = config
         self.model_type = config.model_type
         if self.model_type != "idefics2":
             raise ValueError(f"Unsupported model type: {self.model_type}")
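A toy sketch of the position-embedding lookup above, with illustrative sizes. Note that the old mx.zeros path and the new np.full path both produce all-zero ids, so in this commit every patch receives the same position vector:

# Toy sketch of the lookup; grid and embedding sizes are made up.
import mlx.core as mx
import mlx.nn as nn
import numpy as np

B, grid_h, grid_w, embed_dim = 1, 3, 3, 8
position_embedding = nn.Embedding(grid_h * grid_w, embed_dim)

position_ids = np.full((B, grid_h * grid_w), fill_value=0)  # as committed: all zeros
pos = position_embedding(mx.array(position_ids))
print(pos.shape)  # (1, 9, 8); identical rows because every id is 0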
@@ -226,18 +184,14 @@ def __init__(self, config: VisionConfig):
     def __call__(
         self,
         x: mx.array,
+        patch_attention_mask: Optional[mx.array] = None,
         output_hidden_states: Optional[bool] = None,
     ) -> mx.array:
-        x = self.embeddings(x)
+        B, L, D, C = x.shape
+        if patch_attention_mask is None:
+            patch_size = self.config.patch_size
+            patch_attention_mask = mx.ones(
+                (
+                    B,
+                    L // patch_size,
+                    D // patch_size,
+                )
+            )
+
+        x = self.embeddings(x, mask=None)

         encoder_states = (x,) if output_hidden_states else None
+        patch_size = self.config.patch_size
+
+        mask = None
+        if x.shape[1] > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
+            mask = mask.astype(x.dtype)

-        for l in self.encoder.layers:
-            x = l(x, mask=None)
+        for layers in self.encoder.layers:
+            x = layers(x, mask=None)
         if output_hidden_states:
             encoder_states = encoder_states + (x,)

-        pooler_output = self.post_layernorm(x)
+        pooler_output = self.post_layernorm(x[:, 0, :])

         return pooler_output, x, encoder_states

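A hedged usage sketch of the new three-value return from VisionModel.__call__; the image size and the assumption of a fully populated VisionConfig are illustrative:

# Sketch: consuming (pooler_output, last_hidden, hidden_states) as
# get_input_embeddings in idefics2.py now does.
import mlx.core as mx

vision_model = VisionModel(config)           # assumes a populated VisionConfig
images = mx.random.normal((1, 384, 384, 3))  # NHWC, matching the transpose upstream

pooler_output, last_hidden, hidden_states = vision_model(
    images, output_hidden_states=True
)
image_features = hidden_states[-1]  # the connector consumes the last hidden state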
