Skip to content

Commit

Permalink
update multimodal features merging
Browse files Browse the repository at this point in the history
  • Loading branch information
Blaizzy committed May 25, 2024
1 parent bc34968 commit 2fd4ecd
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions mlx_vlm/models/idefics2/idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __call__(
queries, keys, values, scale=self.scale
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
return self.o_proj(output)


class Idefics2PerceiverLayer(nn.Module):
Expand Down Expand Up @@ -132,7 +132,7 @@ def __call__(
latents = self.input_latents_norm(x)
context = self.input_context_norm(hidden_states)

latents, _ = self.self_attn(latents, context, mask=mask)
latents = self.self_attn(latents, context, mask=mask)

latents = x + latents
r = latents
Expand Down Expand Up @@ -219,7 +219,7 @@ def get_input_embeddings(
pixel_values[0].transpose(0, 2, 3, 1), output_hidden_states=True
)

image_features = hidden_state[-1].astype(pixel_values.dtype)
image_features = pooler_output[None, :].astype(pixel_values.dtype)

image_features = self.connector(image_features, mask=None)

Expand All @@ -229,20 +229,26 @@ def get_input_embeddings(
return final_inputs_embeds

def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids):

image_token_index = self.config.image_token_index
num_images, num_image_patches, embed_dim = image_features.shape
special_image_token_mask = input_ids == image_token_index

reshaped_image_hidden_states = image_features.reshape(-1, embed_dim)
# Positions of <image> tokens in input_ids, assuming batch size is 1
image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()

text_segments = []
start_idx = 0

# Find the positions of the <image> tokens in the input_ids
image_token_positions = mx.array(np.where(special_image_token_mask)[1])
for position in image_positions:
text_segments.append(inputs_embeds[:, start_idx:position])
start_idx = position + 1

# Advanced indexing to place reshaped image features at the corresponding positions
inputs_embeds[0, image_token_positions, :] = reshaped_image_hidden_states
image_embeddings = mx.split(image_features, image_features.shape[0])
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
final_embeddings += [inputs_embeds[:, start_idx:]]

return inputs_embeds
# Create a final embedding of shape
# (1, num_image_patches*num_images + sequence_len, embed_dim)
return mx.concatenate(final_embeddings, axis=1)

def __call__(
self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None
Expand Down

0 comments on commit 2fd4ecd

Please sign in to comment.