Support qwen2 vl #2689

Merged · 18 commits · Oct 30, 2024
1 change: 1 addition & 0 deletions docs/source/supported_models.md
@@ -24,6 +24,7 @@ Text Generation Inference enables serving optimized models.
- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)
- [StarCoder 2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)
- [Qwen 2](https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f)
- [Qwen 2 VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d)
- [Opt](https://huggingface.co/facebook/opt-6.7b)
- [T5](https://huggingface.co/google/flan-t5-xxl)
- [Galactica](https://huggingface.co/facebook/galactica-120b)
@@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape.",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1730164250,
  "id": "",
  "model": "Qwen/Qwen2-VL-7B-Instruct",
  "object": "chat.completion",
  "system_fingerprint": "2.4.1-dev0-native",
  "usage": {
    "completion_tokens": 58,
    "prompt_tokens": 349,
    "total_tokens": 407
  }
}
42 changes: 42 additions & 0 deletions integration-tests/models/test_flash_qwen2_vl.py
@@ -0,0 +1,42 @@
import pytest


@pytest.fixture(scope="module")
def flash_qwen2_vl_handle(launcher):
    with launcher("Qwen/Qwen2-VL-7B-Instruct", cuda_graphs=[0]) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_qwen2(flash_qwen2_vl_handle):
    await flash_qwen2_vl_handle.health(300)
    return flash_qwen2_vl_handle.client


@pytest.mark.private
async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
    response = await flash_qwen2.chat(
        max_tokens=100,
        seed=42,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
                        },
                    },
                    {"type": "text", "text": "Describe this image."},
                ],
            },
        ],
    )

    assert (
        response.choices[0].message.content
        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
    )

    assert response == response_snapshot
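For reference, the same request can be sent to a running server through TGI's OpenAI-compatible chat endpoint. A minimal sketch, assuming a local instance serving Qwen/Qwen2-VL-7B-Instruct on port 8080 (host and port are illustrative, not part of this PR):

import requests

# Hypothetical local deployment; only the message payload mirrors the test above.
response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "tgi",
        "max_tokens": 100,
        "seed": 42,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
                        },
                    },
                    {"type": "text", "text": "Describe this image."},
                ],
            }
        ],
    },
)
print(response.json()["choices"][0]["message"]["content"])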
29 changes: 29 additions & 0 deletions router/src/config.rs
@@ -138,10 +138,39 @@ impl Paligemma {
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Qwen2VlVisionConfig {
    pub(crate) depth: usize,
    pub(crate) embed_dim: usize,
    pub(crate) mlp_ratio: usize,
    pub(crate) num_heads: usize,
    pub(crate) in_chans: usize,
    pub(crate) hidden_size: usize,
    pub(crate) patch_size: usize,
    pub(crate) spatial_merge_size: usize,
    pub(crate) spatial_patch_size: usize,
    pub(crate) temporal_patch_size: usize,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Qwen2Vl {
    pub(crate) vision_config: Qwen2VlVisionConfig,
}

impl Qwen2Vl {
    pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
        let num_pixels = height * width;
        num_pixels / self.vision_config.patch_size.pow(2)
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub enum Config {
    Qwen2Vl(Qwen2Vl),
    LlavaNext(LlavaNext),
    ClipVisionModel(ClipVisionModel),
    Mistral,
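As a sanity check on get_number_of_features above: the result is simply the number of patch_size-square patches in the image. A Python sketch of the same arithmetic, with an assumed patch size of 14 (a value Qwen2-VL configs typically ship; treat it as an assumption here):

def get_number_of_features(height: int, width: int, patch_size: int = 14) -> int:
    # Mirrors the Rust: total pixels divided by the area of one patch.
    return (height * width) // patch_size**2

print(get_number_of_features(224, 420))  # 480 image placeholder tokens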
8 changes: 7 additions & 1 deletion router/src/validation.rs
@@ -594,6 +594,10 @@ fn image_tokens(
        }
        Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
        LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
        Qwen2Vl(config) => format!(
            "<|vision_start|>{}<|vision_end|>",
            "<|image_pad|>".repeat(config.get_number_of_features(height, width))
        ),
        _ => unimplemented!("Image tokens are not supported for this model configuration"),
    }
}
@@ -620,7 +624,9 @@ fn prepare_input<T: TokenizerTrait>(
    use Config::*;
    static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
    let (tokenizer_query, input_chunks) = match config {
-        Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
+        Some(
+            config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)),
+        ) => {
            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
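Taken together with image_tokens, the effect of this new match arm is that a markdown-style image link in the prompt is replaced by the Qwen2 VL placeholder sequence. A simplified Python sketch of that rewriting (the helper name and feature count are hypothetical; the real Rust code also tracks byte offsets to emit input chunks):

import re

RE = re.compile(r"!\[\]\([^\)]*\)")

def rewrite(prompt: str, num_features: int) -> str:
    # Replace each image link with vision_start + N image pads + vision_end.
    tokens = "<|vision_start|>" + "<|image_pad|>" * num_features + "<|vision_end|>"
    return RE.sub(tokens, prompt)

print(rewrite("What is this? ![](https://example.com/cat.png)", 4))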
2 changes: 2 additions & 0 deletions server/text_generation_server/layers/rotary.py
@@ -89,6 +89,8 @@ def static(cls, config, dim, base, device):

        if rope_type == "linear":
            pass
        elif rope_type == "default":
            pass
        elif rope_type == "dynamic":
            scaling_factor = rope_scaling["factor"]
            return DynamicPositionRotaryEmbedding(
20 changes: 20 additions & 0 deletions server/text_generation_server/models/__init__.py
@@ -146,6 +146,9 @@
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.qwen2_vl import (
        Qwen2VLForConditionalGeneration,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
@@ -275,6 +278,11 @@ class ModelType(enum.Enum):
"name": "Qwen 2",
"url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
}
QWEN2_VL = {
"type": "qwen2_vl",
"name": "Qwen 2 VL",
"url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
}
OPT = {
"type": "opt",
"name": "Opt",
@@ -1193,6 +1201,18 @@ def get_model(
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == QWEN2_VL:
        return VlmCausalLM(
            model_id=model_id,
            model_class=Qwen2VLForConditionalGeneration,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            kv_cache_dtype=kv_cache_dtype,
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
        )
    if model_type == MLLAMA:
        if FLASH_ATTENTION:
            return MllamaCausalLM(
@@ -80,6 +80,7 @@ def forward(
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        # TODO This is odd but apparently pali gemma position ids start at 1.
@@ -61,6 +61,11 @@ def __init__(
            config.sliding_window if config.sliding_window is not None else -1
        )
        self.num_heads = config.num_attention_heads
        self.mrope_section = (
            config.rope_scaling.get("mrope_section", None)
            if config.rope_scaling is not None
            else None
        )
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads
@@ -122,6 +127,17 @@ def forward(
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        if self.mrope_section is not None:
            # With mrope, cos/sin carry one table per position stream
            # (temporal, height, width); split them along the rotary dim
            # and re-assemble the sections in the configured order.
            cos = torch.cat(
                [m[i % 3] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
                dim=-1,
            )
            sin = torch.cat(
                [m[i % 3] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
                dim=-1,
            )

        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        if prefill_cache_indices is not None:
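The split/concat above implements the multimodal rotary embedding (mrope): cos and sin carry one table per position stream (temporal, height, width), and chunk i of the rotary dimension is taken from stream i % 3. A standalone sketch with Qwen2-VL-like sizes (mrope_section = [16, 24, 24]; all shapes here are assumptions for illustration):

import torch

mrope_section = [16, 24, 24]       # sums to rotary_dim // 2 = 64
seq_len = 5
cos = torch.randn(3, seq_len, 64)  # one table per (temporal, height, width) stream

# Chunk i along the last dim takes its values from stream i % 3, so the merged
# table mixes temporal, height, and width frequencies in fixed sections.
merged = torch.cat(
    [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))],
    dim=-1,
)
print(merged.shape)  # torch.Size([5, 64])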
@@ -270,9 +286,6 @@ def __init__(self, prefix: str, config, weights):
        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix=f"{prefix}.embed_tokens", weights=weights
-        )
        self.layers = nn.ModuleList(
            [
                Qwen2Layer(
@@ -296,7 +309,7 @@ def __init__(self, prefix: str, config, weights):

    def forward(
        self,
-        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -307,13 +320,16 @@ def forward(
        true_max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
    ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+        hidden_states = inputs_embeds

        # Get rotary cos and sin for this forward
        # Avoid indexing in each layer
+        # flatten position ids from 2D to 1D
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
-            position_ids, true_max_s, hidden_states.dtype
+            position_ids.flatten(), true_max_s, hidden_states.dtype
        )
+        # reshape back to 2D if the position_ids were 2D
+        if position_ids.size(0) != cos.size(0):
+            cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
+            sin = sin.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)

        residual = None
        for i, layer in enumerate(self.layers):
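With mrope, the position ids arrive as a [3, seq_len] tensor, one row per stream, so they are flattened for get_cos_sin and the resulting tables are reshaped to one slice per stream. A shape-only sketch of that round trip (sizes are illustrative):

import torch

seq_len, rot_dim = 5, 64
position_ids = torch.arange(seq_len).repeat(3, 1)  # [3, seq_len]
cos = torch.randn(position_ids.numel(), rot_dim)   # what get_cos_sin returns, flattened

if position_ids.size(0) != cos.size(0):            # 2D (mrope) case
    cos = cos.view(position_ids.size(0), position_ids.size(-1), -1).unsqueeze(2)
print(cos.shape)  # torch.Size([3, 5, 1, 64])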
@@ -352,6 +368,12 @@ def __init__(self, prefix: str, config, weights):
            prefix=f"{prefix}.{suffix}" if prefix else suffix,
            weights=weights,
        )

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
            weights=weights,
        )

        self.max_past = config.sliding_window
        self.max_past_tensor = (
            torch.tensor(config.sliding_window, device=weights.device)
@@ -382,8 +404,10 @@ def forward(
        # kernel requires the true values
        seqlen = seqlen.clamp(max=self.max_past_tensor)

+        inputs_embeds = self.embed_tokens(input_ids)
+
        hidden_states = self.model(
-            input_ids,
+            inputs_embeds,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
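Moving embed_tokens out of Qwen2Model lets a VLM wrapper compute the token embeddings first and overwrite the <|image_pad|> positions with vision features before running the text stack. A toy sketch of that splice (the placeholder id and sizes are made up; the actual wiring lives in qwen2_vl.py, which is not shown in this excerpt):

import torch

IMAGE_PAD_ID = 9    # hypothetical placeholder id, not the real vocab id
hidden_size = 4

input_ids = torch.tensor([1, IMAGE_PAD_ID, IMAGE_PAD_ID, 2])
inputs_embeds = torch.randn(4, hidden_size)  # what embed_tokens would return
image_embeds = torch.randn(2, hidden_size)   # what the vision tower would return

# Overwrite the placeholder slots with image features, then feed the text model.
inputs_embeds[input_ids == IMAGE_PAD_ID] = image_embeds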
@@ -750,6 +750,7 @@ def forward(
        # Unused here
        image_sizes: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if pixel_values is not None:
@@ -180,6 +180,7 @@ def forward(
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if pixel_values is not None and len(pixel_values) > 0: