From f6e4dc733931b46d9e029fe7d35f300a0316f35d Mon Sep 17 00:00:00 2001
From: Prince Canuma
Date: Thu, 2 May 2024 07:48:02 +0200
Subject: [PATCH] working model

---
 mlx_vlm/models/idefics2/idefics2.py |  27 +++----
 mlx_vlm/models/idefics2/language.py |   2 +-
 mlx_vlm/models/idefics2/vision.py   | 103 +++++++++++-----------
 3 files changed, 52 insertions(+), 80 deletions(-)

diff --git a/mlx_vlm/models/idefics2/idefics2.py b/mlx_vlm/models/idefics2/idefics2.py
index 5331fc6..c5911f6 100644
--- a/mlx_vlm/models/idefics2/idefics2.py
+++ b/mlx_vlm/models/idefics2/idefics2.py
@@ -101,7 +101,7 @@ def __call__(
             values = mx.concatenate([value_cache, values], axis=2)
 
         output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+            queries, keys, values, scale=self.scale
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output), (keys, values)
@@ -128,12 +128,11 @@ def __call__(
         x: mx.array,
         hidden_states: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> mx.array:
         latents = self.input_latents_norm(x)
         context = self.input_context_norm(hidden_states)
 
-        latents, cache = self.self_attn(latents, context, mask=mask, cache=cache)
+        latents, _ = self.self_attn(latents, context, mask=mask)
 
         latents = x + latents
         r = latents
@@ -141,7 +140,7 @@ def __call__(
         latents = self.post_attention_layernorm(latents)
         latents = self.mlp(latents)
         latents = r + latents
-        return latents, cache
+        return latents
 
 
 class Idefics2PerceiverResampler(nn.Module):
@@ -157,18 +156,15 @@ def __init__(self, config: ModelConfig):
         ]
         self.norm = nn.RMSNorm(self.hidden_size, eps=config.text_config.rms_norm_eps)
 
-    def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache=None):
-
-        if cache is None:
-            cache = [None] * len(self.layers)
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None):
 
         h = mx.expand_dims(self.latents, axis=0)
         h = mx.repeat(h, x.shape[0], axis=0)
 
-        for e, layer in enumerate(self.layers):
-            h, cache[e] = layer(h, x, mask=mask, cache=cache[e])
+        for layer in self.layers:
+            h = layer(h, x, mask=mask)
 
-        return self.norm(h), cache
+        return self.norm(h)
 
 
 class MLP(nn.Module):
@@ -218,13 +214,14 @@ def get_input_embeddings(
             return self.language_model(input_ids)
 
         inputs_embeds = self.language_model.embed_tokens(input_ids)
-        *_, hidden_state = self.vision_model(
+        print(pixel_values[0].shape)
+        pooler_output, embeddings, hidden_state = self.vision_model(
            pixel_values[0].transpose(0, 2, 3, 1), output_hidden_states=True
         )
         image_features = hidden_state[-1].astype(pixel_values.dtype)
 
-        image_features, _ = self.connector(image_features, mask=None)
+        image_features = self.connector(image_features, mask=None)
 
         final_inputs_embeds = self._prepare_inputs_for_multimodal(
             image_features, inputs_embeds, input_ids
         )
@@ -243,9 +240,7 @@ def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_id
         image_token_positions = mx.array(np.where(special_image_token_mask)[1])
 
         # Advanced indexing to place reshaped image features at the corresponding positions
-        inputs_embeds[0, image_token_positions, :] = reshaped_image_hidden_states[
-            : len(image_token_positions)
-        ]
+        inputs_embeds[0, image_token_positions, :] = reshaped_image_hidden_states
 
         return inputs_embeds
 
diff --git a/mlx_vlm/models/idefics2/language.py b/mlx_vlm/models/idefics2/language.py
index f412cde..b4a0dcb 100644
--- a/mlx_vlm/models/idefics2/language.py
+++ b/mlx_vlm/models/idefics2/language.py
@@ -18,7 +18,7 @@ class TextConfig:
     vocab_size: int
     num_key_value_heads: int
     rope_theta: float = 10000.0
-    rope_traditional: bool = True
+    rope_traditional: bool = False
     tie_word_embeddings: bool = False
 
     @classmethod
diff --git a/mlx_vlm/models/idefics2/vision.py b/mlx_vlm/models/idefics2/vision.py
index a7f67fc..7e4dae4 100644
--- a/mlx_vlm/models/idefics2/vision.py
+++ b/mlx_vlm/models/idefics2/vision.py
@@ -4,6 +4,7 @@
 
 import mlx.core as mx
 import mlx.nn as nn
+import numpy as np
 
 
 @dataclass
@@ -79,59 +80,17 @@ def __init__(
         self.v_proj = nn.Linear(value_input_dims, value_dims, bias=True)
         self.out_proj = nn.Linear(value_dims, value_output_dims, bias=True)
 
-    def __call__(self, queries, keys, values, mask=None):
-        queries = self.q_proj(queries)
-        keys = self.k_proj(keys)
-        values = self.v_proj(values)
+    def __call__(self, x: mx.array, mask=None):
+        B, L, _ = x.shape
+        queries = self.q_proj(x)
+        keys = self.k_proj(x)
+        values = self.v_proj(x)
 
         num_heads = self.num_heads
 
-        B, L, D = queries.shape
-        _, S, _ = keys.shape
-        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
-        )
-        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.out_proj(output)
-
-class MHA(nn.Module):
-    def __init__(
-        self,
-        dims: int,
-        num_heads: int,
-        bias: bool = False,
-    ):
-        super().__init__()
-
-        if (dims % num_heads) != 0:
-            raise ValueError(
-                "The input feature dimensions should be divisible by the "
-                f"number of heads ({dims} % {num_heads}) != 0"
-            )
-
-        self.num_heads = num_heads
-        head_dim = dims // num_heads
-        self.scale = head_dim**-0.5
-
-        self.in_proj = nn.Linear(dims, dims * 3, bias=bias)
-        self.out_proj = nn.Linear(dims, dims, bias=bias)
-
-    def __call__(self, queries: mx.array, kv: mx.array, mask=None, cache=None):
-        B, L, D = queries.shape
-
-        qkv = self.in_proj(queries)
-        _, keys, values = mx.split(qkv, 3, axis=-1)
-
-        num_heads = self.num_heads
-        B, L, D = queries.shape
-        _, S, _ = keys.shape
         queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
 
         output = mx.fast.scaled_dot_product_attention(
             queries, keys, values, scale=self.scale, mask=mask
@@ -164,7 +123,7 @@ def __init__(self, config: VisionConfig):
 
     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
         y = self.layer_norm1(x)
-        y = self.self_attn(y, y, y, mask)
+        y = self.self_attn(y, mask)
         x = x + y
         y = self.layer_norm2(x)
         y = self.mlp(y)
@@ -190,32 +149,31 @@ def __init__(self, config: VisionConfig):
             out_channels=self.embed_dim,
             kernel_size=self.patch_size,
             stride=self.patch_size,
-            dilation=1,
         )
 
         self.num_patches = (self.image_size // self.patch_size) ** 2
         self.num_positions = self.num_patches
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
 
-    def __call__(self, x: mx.array) -> mx.array:
-        batch_size, max_im_h, max_im_w, _ = x.shape
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+        B, H, W, C = x.shape
         patch_embeddings = self.patch_embedding(x)
         patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
         max_nb_patches_h, max_nb_patches_w = (
-            max_im_h // self.patch_size,
-            max_im_w // self.patch_size,
+            H // self.patch_size,
+            W // self.patch_size,
         )
-        position_ids = mx.zeros(
-            (batch_size, max_nb_patches_h * max_nb_patches_w)
-        ).astype(mx.uint64)
+        position_ids = np.full((B, max_nb_patches_h * max_nb_patches_w), fill_value=0)
+
         embeddings = patch_embeddings
-        embeddings += self.position_embedding(position_ids)
+        embeddings += self.position_embedding(mx.array(position_ids))
         return embeddings
 
 
 class VisionModel(nn.Module):
     def __init__(self, config: VisionConfig):
         super().__init__()
+        self.config = config
         self.model_type = config.model_type
         if self.model_type != "idefics2":
             raise ValueError(f"Unsupported model type: {self.model_type}")
@@ -226,18 +184,37 @@ def __call__(
         self,
         x: mx.array,
+        patch_attention_mask: Optional[mx.array] = None,
         output_hidden_states: Optional[bool] = None,
     ) -> mx.array:
-        x = self.embeddings(x)
+
+        B, L, D, C = x.shape
+        if patch_attention_mask is None:
+            patch_size = self.config.patch_size
+            patch_attention_mask = mx.ones(
+                (
+                    B,
+                    L // patch_size,
+                    D // patch_size,
+                )
+            )
+
+        x = self.embeddings(x, mask=None)
 
         encoder_states = (x,) if output_hidden_states else None
+        patch_size = self.config.patch_size
+
+        mask = None
+        if x.shape[1] > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
+            mask = mask.astype(x.dtype)
 
-        for l in self.encoder.layers:
-            x = l(x, mask=None)
+        for layers in self.encoder.layers:
+            x = layers(x, mask=None)
             if output_hidden_states:
                 encoder_states = encoder_states + (x,)
 
-        pooler_output = self.post_layernorm(x)
+        pooler_output = self.post_layernorm(x[:, 0, :])
 
         return pooler_output, x, encoder_states
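
For reference, a minimal sketch of how the reworked vision tower and perceiver connector can be exercised after this patch. Only the attribute names (vision_model, connector), the NCHW-to-NHWC transpose, and the new return values come from the diff above; the load() helper, the checkpoint id, and the 980x980 input size are assumptions, not part of this change.

    import mlx.core as mx
    from mlx_vlm import load  # assumption: mlx_vlm exposes load() returning (model, processor)

    # Hypothetical MLX-converted Idefics2 checkpoint id; substitute the repo you actually use.
    model, processor = load("mlx-community/idefics2-8b-4bit")

    # Dummy pixel_values in the (batch, num_images, C, H, W) layout the patch indexes into;
    # the 980x980 resolution is an assumption about the preprocessed image size.
    pixel_values = mx.zeros((1, 1, 3, 980, 980))

    # The vision tower now returns (pooler_output, embeddings, hidden_states).
    pooler_output, embeddings, hidden_states = model.vision_model(
        pixel_values[0].transpose(0, 2, 3, 1),  # NCHW -> NHWC, as in get_input_embeddings
        output_hidden_states=True,
    )

    # The perceiver resampler ("connector") now returns a single array instead of (latents, cache).
    image_features = model.connector(hidden_states[-1].astype(pixel_values.dtype), mask=None)
    print(image_features.shape)  # roughly (1, num_latents, text_hidden_size)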