Support qwen2 vl #2689
Conversation
```diff
@@ -144,7 +144,7 @@ def load_qkv(
             num_key_value_heads=num_key_value_heads,
         )
         if bias:
-            raise NotImplementedError("packed_qkv only implemented for baichuan")
+            bias = weights.get_tensor(f"{prefix}.bias")
```
This is wrong, no?
We get the whole bias on row-parallel layers; for column-parallel ones you need to take the actual slice, which for qkv I think needs to follow the same layout as the weights (except on dim=0 instead of dim=1).
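A toy illustration of that layout (plain tensors, not the repository's Weights API; `world_size`, `rank` and `hidden` are made-up values):

```python
import torch

# With tensor parallelism each rank owns its slice of q, k and v, so the fused
# qkv bias has to be cut per projection along dim=0 and re-concatenated,
# mirroring how the weight rows are assembled (dim=0 rather than dim=1 because
# the bias is 1-D).
world_size, rank = 2, 0
hidden = 8
q_bias = torch.arange(hidden, dtype=torch.float32)
k_bias = torch.arange(hidden, dtype=torch.float32) + 100
v_bias = torch.arange(hidden, dtype=torch.float32) + 200

def shard(t: torch.Tensor) -> torch.Tensor:
    block = t.shape[0] // world_size
    return t[rank * block : (rank + 1) * block]

# This rank's fused bias: its slice of q, then its slice of k, then its slice of v.
qkv_bias = torch.cat([shard(q_bias), shard(k_bias), shard(v_bias)], dim=0)
print(qkv_bias.shape)  # torch.Size([12]) for hidden=8, world_size=2
```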
After comparing with transformers, it seems that `weights.get_tensor(f"{prefix}.bias")` and `weights.get_sharded(f"{prefix}.qkv.bias", dim=0)` both return the exact same bias as the one in the reference.
I've reverted the change within `tensor_parallel.py::load_qkv` in favor of setting the bias after creating the linear in `qwen_vl.py` via `weights.get_sharded(f"{prefix}.qkv.bias", dim=0)`.
for reference:
```python
self.qkv = TensorParallelColumnLinear.load_qkv(
    config,
    prefix=f"{prefix}.qkv",
    weights=weights,
    bias=False,
    num_heads=self.num_heads,
    num_key_value_heads=self.num_heads,
)
self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
```
This hopefully makes the qkv loading a bit clearer.
```diff
-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        # TODO: correctly handle the multimodal case
+        if False:
+            self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
```
This should be correct. Ignore tiny differences there; this code is exactly what you have underneath (and much more efficient).
I read the transformers comment on this, and from what I'm reading they are just applying part of the tensors there, so regular slicing should do the work.
The problem with the other part is that our cos, sin are laid out differently than theirs, so you're going to have issues keeping transformers code and merging it with our own.
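For what it's worth, a minimal sketch of the "slice instead of porting their code" idea, with toy shapes and a hypothetical `apply_partial_rotary` helper rather than the repository's rotary implementation:

```python
import torch

# Only the first rot_dim channels of each head are rotated; a plain slice plus
# an in-place recombination is enough, no need to carry over the full
# transformers code path.
def apply_partial_rotary(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                         rot_dim: int) -> torch.Tensor:
    q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]
    half = rot_dim // 2
    # rotate_half applied to the sliced part only
    rotated = torch.cat([-q_rot[..., half:], q_rot[..., :half]], dim=-1)
    q_rot = q_rot * cos + rotated * sin
    return torch.cat([q_rot, q_pass], dim=-1)

seq, heads, head_dim, rot_dim = 4, 2, 16, 8
q = torch.randn(seq, heads, head_dim)
cos = torch.randn(seq, 1, rot_dim)
sin = torch.randn(seq, 1, rot_dim)
out = apply_partial_rotary(q, cos, sin, rot_dim)
print(out.shape)  # torch.Size([4, 2, 16])
```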
```python
        hidden_states = self.embed_tokens(input_ids)

        # if inputs_embeds are supplied from an external model (vision model) then avoid embedding input_ids
        if inputs_embeds is not None:
```
No, remove this.
The way it's done here is that we lift input_ids to the parent class, and this always takes inputs_embeds. Makes the signatures much cleaner. (It's already done this way for llama if you want to check.)
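Roughly the shape of that pattern, sketched with simplified, hypothetical classes rather than the actual TGI signatures:

```python
import torch
from torch import nn

class TextModel(nn.Module):
    # The text model only ever sees inputs_embeds; no embed_tokens call in here.
    def forward(self, inputs_embeds: torch.Tensor, position_ids: torch.Tensor):
        hidden_states = inputs_embeds
        # ... transformer layers ...
        return hidden_states

class VlmModel(nn.Module):
    # The multimodal wrapper owns embed_tokens and decides whether to splice
    # image features into the embeddings before calling the text model.
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.text_model = TextModel()

    def forward(self, input_ids, position_ids, image_embeds=None, image_token_id=0):
        inputs_embeds = self.embed_tokens(input_ids)
        if image_embeds is not None:
            # overwrite placeholder token embeddings with vision features (inference only)
            inputs_embeds[input_ids == image_token_id] = image_embeds
        return self.text_model(inputs_embeds, position_ids)
```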
Yes, agreed, this is much cleaner. I've updated the classes in the latest commit to follow the same pattern as llama.
```diff
@@ -306,12 +335,24 @@ def forward(
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
```
No attention_mask
removed in latest commit
```python
            image_mask = (
                (input_ids == self.image_token_id)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            image_embeds = image_embeds.to(
                inputs_embeds.device, inputs_embeds.dtype
            )
            # input embeddings are masked with image embeddings
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
```
Why doesn't
`inputs_embeds[input_ids == self.image_token_id] = image_embeds`
work?
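A toy check of the equivalence (made-up tensors; assumes image_embeds already has one row per image token and matching dtype/device):

```python
import torch

# Boolean-index assignment on the image-token positions does the same job as
# the masked_scatter dance above, with far less ceremony.
image_token_id = 99
input_ids = torch.tensor([1, 99, 99, 2])
inputs_embeds = torch.zeros(4, 3)
image_embeds = torch.ones(2, 3)  # one row per image token

inputs_embeds[input_ids == image_token_id] = image_embeds
print(inputs_embeds)
# rows 1 and 2 are now the image embeddings, rows 0 and 3 are untouched
```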
Ha, yeah, that's a much better way to write this! I've included this change, along with an overall cleaner rewrite of this logic, in the latest commit.
```python
            attention_mask = torch.ones_like(
                input_ids, dtype=torch.bool, device=input_ids.device
            )
            inputs_embeds = self.text_model.embed_tokens(input_ids)
```
When you lift this, it will be gone from the text_model and live here directly.
updated in latest commit
```python
        image_index, video_index = 0, 0

        for i, input_ids in enumerate(total_input_ids):
            if attention_mask is not None:
```
This whole thing is extremely poor code (lots of loops, lots of CPU/GPU back and forth).
I think ditching it altogether will be easier than trying to adapt it.
Yeah, agreed. In the latest commit I've moved this logic into a `get_position_ids` function and rewritten a simpler version that avoids most of the GPU/CPU copies. I'll revisit later to see if I can simplify further (avoid the loop), but the changes should already provide a bit of a performance and readability improvement.
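For discussion, a heavily simplified sketch of what such a function has to produce for a single sequence with at most one image (hypothetical helper, not the PR's actual code; videos and spatial merging ignored):

```python
import torch

# Every token gets three position ids (temporal, height, width). Text tokens
# use the same running counter for all three; image tokens use their (t, h, w)
# grid coordinates offset by the text position before the image.
def get_position_ids_sketch(input_ids: torch.Tensor, image_token_id: int,
                            grid_thw) -> torch.Tensor:
    seq_len = input_ids.shape[0]
    image_positions = (input_ids == image_token_id).nonzero(as_tuple=True)[0]
    if image_positions.numel() == 0:
        # text only: plain arange repeated on the three rope sections
        return torch.arange(seq_len).unsqueeze(0).expand(3, -1).clone()

    position_ids = torch.empty(3, seq_len, dtype=torch.long)
    start = image_positions[0].item()
    t, h, w = grid_thw
    # text before the image
    position_ids[:, :start] = torch.arange(start)
    # image block: coordinates within the grid, offset by the preceding text length
    coords_t = torch.arange(t).view(t, 1, 1).expand(t, h, w).reshape(-1)
    coords_h = torch.arange(h).view(1, h, 1).expand(t, h, w).reshape(-1)
    coords_w = torch.arange(w).view(1, 1, w).expand(t, h, w).reshape(-1)
    end = start + t * h * w
    position_ids[:, start:end] = torch.stack([coords_t, coords_h, coords_w]) + start
    # text after the image continues from the largest position used so far
    next_pos = position_ids[:, :end].max().item() + 1
    position_ids[:, end:] = torch.arange(next_pos, next_pos + seq_len - end)
    return position_ids
```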
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
d1bc32b to f2a1b1b
```python
            position_ids = position_ids.repeat(3, 1, 1).clone()
            batch.position_ids = position_ids[0, 0, :]
```
This seems very wrong, no?
```python
            position_ids = self.model.get_position_ids(
                input_ids.unsqueeze(0), batch.image_grid_thw
            )
            batch.position_ids = position_ids[0, 0, :]
```
Why create so many position ids, just to discard them?
```python
        )
        self.device = weights.device

    def get_position_ids(
```
Seems overly complex and bloated.
Let's keep this if it's working, but it's definitely fixable I think
```python
        if hasattr(self.model, "get_position_ids"):
            if position_ids.shape[0] != 1:
                position_ids = self.model.get_position_ids(
                    input_ids.unsqueeze(0), batch.image_grid_thw
```
No unsqueeze, fix the function.
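Something like the following is probably what's meant (illustrative signature only): the function takes the flat (seq_len,) input_ids the server already has and returns (3, seq_len) ids, so callers never unsqueeze or index a batch dimension back out.

```python
import torch

def get_position_ids(input_ids: torch.Tensor, image_grid_thw=None) -> torch.Tensor:
    if image_grid_thw is None:
        # text-only fallback: same running counter on all three rope sections
        pos = torch.arange(input_ids.shape[0], device=input_ids.device)
        return pos.unsqueeze(0).expand(3, -1).clone()
    raise NotImplementedError("multimodal case sketched earlier in the thread")

input_ids = torch.tensor([5, 7, 9])
print(get_position_ids(input_ids).shape)  # torch.Size([3, 3])
```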
Merging to enable qwen2-vl. Will follow up with an improvement PR soon, specifically:
This is a work-in-progress PR to support qwen2-vl. Currently these changes include loading the model weights and a functioning vision model. Remaining work is to adjust the existing qwen2 model to handle multimodal requests/positional embeddings.
status:
remaining:
further