Skip to content

Commit

Permalink
2x speed up for qwen2-vl (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
Blaizzy authored Oct 16, 2024
1 parent 0a39868 commit 51b0e7c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions mlx_vlm/models/qwen2_vl/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ def _merge_input_ids_with_image_features(

# Positions of <image> tokens in input_ids, assuming batch size is 1
image_positions = input_ids == image_token_index
inputs_embeds = np.array(inputs_embeds.astype(mx.float32))
inputs_embeds[image_positions] = image_features
image_indices = np.where(image_positions)[1].tolist()
inputs_embeds[:, image_indices, :] = image_features.astype(mx.float32)

# TODO: Add video features

return mx.array(inputs_embeds)
return inputs_embeds

def __call__(
self,
Expand Down
2 changes: 1 addition & 1 deletion mlx_vlm/models/qwen2_vl/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def __call__(
# Concatenate the cu_seqlens for all items in the batch
cu_seqlens = mx.concatenate(cu_seqlens)

cu_seqlens = mx.cumsum(cu_seqlens.astype(mx.int32))
cu_seqlens = mx.cumsum(cu_seqlens.astype(mx.int32), axis=0)
cu_seqlens = mx.pad(cu_seqlens, (1, 0), mode="constant", constant_values=0)

encoder_states = (hidden_states,) if output_hidden_states else None
Expand Down

0 comments on commit 51b0e7c

Please sign in to comment.