Support qwen2 vl #2689

Merged: 18 commits from support-qwen2-vl into main on Oct 30, 2024
Conversation

drbh (Collaborator) commented on Oct 24, 2024

This is a work-in-progress PR to support qwen2-vl. Currently these changes include loading the model weights and a functioning vision model. The remaining work is to adjust the existing qwen2 model to handle multimodal requests and positional embeddings.

status:

  • load weights
  • support prefill warmup run
  • accept chat image request and process image
  • align vision model output with reference impl
  • correctly merge the processed image and text model
  • avoid unnecessary reshapes and allocations at runtime where possible

remaining:

  • resolve remaining bug with position ids
  • align test output with reference
  • cleanup remaining todos/refactors/improvements

further:

  • make improvements
  • improve test coverage

@@ -144,7 +144,7 @@ def load_qkv(
            num_key_value_heads=num_key_value_heads,
        )
        if bias:
-            raise NotImplementedError("packed_qkv only implemented for baichuan")
+            bias = weights.get_tensor(f"{prefix}.bias")
Collaborator (reviewer):
This is wrong, no?

We get the whole bias for Row Parallel; for Column Parallel you need to take the actual slice, and for qkv I think you need to follow the same layout as the weights (except on dim=0 instead of dim=1).
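For illustration, here is a minimal sketch of what "taking the actual slice" means for a column-parallel bias. The helper name shard_qkv_bias and the rank/world_size arguments are made up for this example, and it ignores the fused q/k/v packing that the real layout also has to respect; it is not TGI's actual API:

import torch

def shard_qkv_bias(full_bias: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # A column-parallel linear splits its output features across ranks, so the
    # 1-D bias has to be sliced along dim=0 with the same per-rank layout as
    # the weight rows; returning the full bias on every rank would be wrong.
    out_features = full_bias.shape[0]
    assert out_features % world_size == 0
    block = out_features // world_size
    return full_bias[rank * block : (rank + 1) * block]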

drbh (author):

After comparing with transformers, it seems that weights.get_tensor(f"{prefix}.bias") and weights.get_sharded(f"{prefix}.qkv.bias", dim=0) return exactly the same bias as the one in the reference.

I've reverted the change within tensor_parallel.py::load_qkv in favor of setting the bias after creating the linear in qwen_vl.py via weights.get_sharded(f"{prefix}.qkv.bias", dim=0).

for reference:

self.qkv = TensorParallelColumnLinear.load_qkv(
    config,
    prefix=f"{prefix}.qkv",
    weights=weights,
    bias=False,
    num_heads=self.num_heads,
    num_key_value_heads=self.num_heads,
)
self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)

This hopefully makes the qkv loading a bit clearer.

-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        # TODO: correctly handle the multimodal case
+        if False:
+            self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
Collaborator (reviewer):

This should be correct. Ignore tiny differences there; this code is exactly what you have underneath (and much more efficient).

I read the transformers comment on this, and from what I'm reading they are just applying part of the tensors there, so regular slicing should do the work.

The problem with the other part is that our cos, sin are laid out differently than theirs, so you're going to have issues keeping the transformers code and merging it with our own.
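For context, a rough sketch of what "applying part of the tensors" via regular slicing looks like; rotary_dim, the tensor shapes, and the cos/sin layout are assumptions for illustration, not the rotary kernel TGI actually uses:

import torch

def apply_partial_rotary(query, key, cos, sin, rotary_dim):
    # Rotate only the first `rotary_dim` channels and pass the rest through,
    # assuming cos/sin are already broadcastable to [..., rotary_dim].
    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    q_rot, q_pass = query[..., :rotary_dim], query[..., rotary_dim:]
    k_rot, k_pass = key[..., :rotary_dim], key[..., rotary_dim:]
    q_rot = q_rot * cos + rotate_half(q_rot) * sin
    k_rot = k_rot * cos + rotate_half(k_rot) * sin
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)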

hidden_states = self.embed_tokens(input_ids)

# if inputs_embeds are supplied from an external model (vision model) then avoid embedding input_ids
if inputs_embeds is not None:
Collaborator (reviewer):

No, remove this.

The way it's done here is that we lift input_ids to the parent class, and this always takes inputs_embeds. It makes the signatures much cleaner. (It's already done this way for llama if you want to check.)
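A minimal sketch of that pattern (the class and argument names are made up for illustration; the real llama/qwen2 implementations differ in detail):

import torch.nn as nn

class MultimodalModelSketch(nn.Module):
    # Illustrative only: the wrapper owns embed_tokens, so the inner text
    # model's forward signature takes inputs_embeds and never needs an
    # `if inputs_embeds is not None` branch.
    def __init__(self, text_model, vision_model, embed_tokens):
        super().__init__()
        self.text_model = text_model
        self.vision_model = vision_model
        self.embed_tokens = embed_tokens

    def forward(self, input_ids, pixel_values=None, **kwargs):
        inputs_embeds = self.embed_tokens(input_ids)
        if pixel_values is not None:
            # vision features get merged into inputs_embeds here,
            # outside the text model
            ...
        return self.text_model(inputs_embeds=inputs_embeds, **kwargs)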

drbh (author):

Yes, agreed, this is much cleaner. I've updated the classes in the latest commit to follow the same pattern as llama.

@@ -306,12 +335,24 @@ def forward(
        max_s: int,
        true_max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
+       inputs_embeds: Optional[torch.Tensor] = None,
+       attention_mask: Optional[torch.Tensor] = None,
Collaborator (reviewer):

No attention_mask

drbh (author):

removed in latest commit

Comment on lines 405 to 415

            image_mask = (
                (input_ids == self.image_token_id)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            image_embeds = image_embeds.to(
                inputs_embeds.device, inputs_embeds.dtype
            )
            # input embeddings are masked with image embeddings
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
Collaborator (reviewer):

Why doesn't

inputs_embeds[input_ids == self.image_token_id] = image_embeds

work?

drbh (author):

Ha, yes, that's a much better way to write this! I've included this change along with an overall rewrite of this logic in the latest commit.
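For anyone following along, a tiny self-contained check (toy shapes and made-up values) showing that the boolean-index assignment produces the same result as the masked_scatter version above:

import torch

image_token_id = 99
input_ids = torch.tensor([7, 99, 8, 99, 9])  # positions 1 and 3 are image tokens
inputs_embeds = torch.zeros(5, 4)
image_embeds = torch.arange(8, dtype=torch.float32).reshape(2, 4)

# masked_scatter version, as in the original diff
mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
scattered = inputs_embeds.masked_scatter(mask, image_embeds)

# boolean-index assignment suggested above
indexed = inputs_embeds.clone()
indexed[input_ids == image_token_id] = image_embeds

assert torch.equal(scattered, indexed)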

        attention_mask = torch.ones_like(
            input_ids, dtype=torch.bool, device=input_ids.device
        )
        inputs_embeds = self.text_model.embed_tokens(input_ids)
Collaborator (reviewer):

When you lift this, this will be gone from the text_model and be here directly.

drbh (author):

updated in latest commit

        image_index, video_index = 0, 0

        for i, input_ids in enumerate(total_input_ids):
            if attention_mask is not None:
Collaborator (reviewer):

This whole thing is extremely poor code (lots of loops, lots of CPU/GPU back&forth).

I think ditching it altogether will be easier than trying to adapt it.

drbh (author):

Yes, agreed. In the latest commit I've moved this logic into a get_position_ids function and rewritten it as a simpler version that avoids most of the GPU/CPU copies. I'll revisit later to see if I can simplify further (avoid the loop), but the changes should already provide a bit of a performance and readability improvement.
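For readers unfamiliar with the 3-D (temporal/height/width) position ids qwen2-vl uses, here is a deliberately simplified sketch of the idea. It assumes a single image whose placeholder tokens are contiguous and ignores spatial merging, so it is not the PR's actual get_position_ids:

import torch

def get_position_ids_sketch(input_ids, image_grid_thw, image_token_id):
    # Returns a [3, seq_len] tensor: text tokens share one running index on
    # all three axes, while image tokens get (t, h, w) grid coordinates.
    seq_len = input_ids.shape[0]
    positions = torch.arange(seq_len, device=input_ids.device).repeat(3, 1)

    image_mask = input_ids == image_token_id
    if image_mask.any():
        t, h, w = (int(x) for x in image_grid_thw[0])
        start = int(image_mask.nonzero()[0])
        t_idx = torch.arange(t).repeat_interleave(h * w)
        h_idx = torch.arange(h).repeat_interleave(w).repeat(t)
        w_idx = torch.arange(w).repeat(t * h)
        grid = torch.stack([t_idx, h_idx, w_idx]).to(input_ids.device) + start
        n = grid.shape[1]
        positions[:, start : start + n] = grid
        # text after the image continues after the largest index used so far
        positions[:, start + n :] += int(grid.max()) + 1 - (start + n)
    return positions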

@HuggingFaceDocBuilderDev:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

drbh force-pushed the support-qwen2-vl branch from d1bc32b to f2a1b1b on October 28, 2024
drbh marked this pull request as ready for review on October 29, 2024
drbh requested a review from Narsil on October 30, 2024
Comment on lines +373 to +374
position_ids = position_ids.repeat(3, 1, 1).clone()
batch.position_ids = position_ids[0, 0, :]
Collaborator (reviewer):

This seems very wrong, no ?

position_ids = self.model.get_position_ids(
    input_ids.unsqueeze(0), batch.image_grid_thw
)
batch.position_ids = position_ids[0, 0, :]
Collaborator (reviewer):

Why create so many position ids, just to discard them?

        )
        self.device = weights.device

    def get_position_ids(
Collaborator (reviewer):

Seems overly complex and bloated.

Let's keep this if it's working, but it's definitely fixable I think

if hasattr(self.model, "get_position_ids"):
    if position_ids.shape[0] != 1:
        position_ids = self.model.get_position_ids(
            input_ids.unsqueeze(0), batch.image_grid_thw
Collaborator (reviewer):

No unsqueeze, fix the function.

drbh (author) commented on Oct 30, 2024

Merging to enable qwen2-vl.

Will follow up with an improvement PR soon, specifically:

  • improve decode
  • improve vision head
  • improve batch to handle multi dimensional position ids
  • remove complex position logic if possible

drbh merged commit befd9f6 into main on Oct 30, 2024 (12 of 13 checks passed)
drbh deleted the support-qwen2-vl branch on October 30, 2024