2 changes: 2 additions & 0 deletions python/sglang/srt/configs/model_config.py
@@ -749,6 +749,8 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = False
"Qwen2AudioForConditionalGeneration",
"Qwen2VLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"KimiVLForConditionalGeneration",
"InternVLChatModel",
"InternS1ForConditionalGeneration",
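For context: registering these two architecture names is what routes Qwen3-VL checkpoints through the multimodal generation path. A minimal sketch of the membership check such an allowlist feeds, assuming the surrounding code simply tests membership (the helper and constant below are hypothetical, not the shipped names):

```python
from typing import List

# Hypothetical allowlist mirroring the list this hunk extends.
_MULTIMODAL_GEN_ARCHS = {
    "Qwen2_5_VLForConditionalGeneration",
    "Qwen3VLForConditionalGeneration",
    "Qwen3VLMoeForConditionalGeneration",
}


def is_multimodal_gen_arch(model_architectures: List[str]) -> bool:
    # True if any architecture declared in the HF config is registered.
    return any(arch in _MULTIMODAL_GEN_ARCHS for arch in model_architectures)
```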
586 changes: 586 additions & 0 deletions python/sglang/srt/configs/qwen3_vl.py

Large diffs are not rendered by default.
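The new `qwen3_vl.py` config module is collapsed here. From how the rest of this diff uses it, it must at least expose the deepstack layer indexes on the vision config; a rough sketch of that piece alone (all defaults and names other than `deepstack_visual_indexes` are assumptions, not the shipped values):

```python
from transformers import PretrainedConfig


class Qwen3VLVisionConfig(PretrainedConfig):
    model_type = "qwen3_vl"

    def __init__(
        self,
        hidden_size: int = 1152,  # assumed default
        deepstack_visual_indexes=(8, 16, 24),  # vision layers tapped for deepstack (assumed)
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.deepstack_visual_indexes = list(deepstack_visual_indexes)
```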

2 changes: 1 addition & 1 deletion python/sglang/srt/layers/rotary_embedding.py
@@ -1177,7 +1177,7 @@ def get_rope_index(

             time_tensor_long = time_tensor.long()
             t_index = time_tensor_long.flatten()
-        elif model_type == "qwen2_vl":
+        elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
             t_index = (
                 torch.arange(llm_grid_t)
                 .view(-1, 1)
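The extended branch computes the temporal component of the M-RoPE position index identically for Qwen2-VL and Qwen3-VL: each frame index is repeated across every spatial position of its frame. A self-contained sketch of the computation this hunk continues into (grid sizes are illustrative):

```python
import torch

llm_grid_t, llm_grid_h, llm_grid_w = 2, 2, 3  # illustrative grid sizes

# Each temporal index is broadcast over all h*w spatial positions of its frame.
t_index = (
    torch.arange(llm_grid_t)
    .view(-1, 1)
    .expand(-1, llm_grid_h * llm_grid_w)
    .flatten()
)
print(t_index)  # tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
```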
49 changes: 43 additions & 6 deletions python/sglang/srt/managers/mm_utils.py
@@ -507,6 +507,7 @@ def embed_mm_inputs(
         Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
     placeholder_tokens: dict[Modality, List[int]] = None,
+    use_deepstack: bool = False,
 ) -> Optional[torch.Tensor]:
     """
     Embed multimodal inputs and integrate them with text token embeddings.
@@ -522,7 +523,7 @@
     Returns:
         Combined embedding tensor with multimodal content integrated
     """
-
+    other_info = {}
     if mm_inputs_list is None:
         return None

@@ -532,7 +533,7 @@
     for mm_inputs in mm_inputs_list:
         item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]

-    embeddings, masks = [], []
+    embeddings, masks, deepstack_embeddings = [], [], []
     # 2. Get multimodal embedding separately
     # Try get mm embedding if any
     for modality in Modality.all():
@@ -578,6 +579,12 @@
                 extend_length=extend_seq_lens,
                 items_offset_list=items_offsets,
             )
+
+            if use_deepstack and embedding is not None:
+                embedding, deepstack_embedding = (
+                    multimodal_model.separate_deepstack_embeds(embedding)
+                )
+                deepstack_embeddings += [deepstack_embedding]
             embeddings += [embedding]
             masks += [mask]
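`separate_deepstack_embeds` lives on the Qwen3-VL model and is not shown in this diff. The contract implied here: the vision tower emits, per visual token, the main embedding concatenated with one extra hidden-size chunk per deepstack layer, and the method splits the two parts along the last dimension. A minimal sketch under that assumption:

```python
import torch


def separate_deepstack_embeds(embedding: torch.Tensor, hidden_size: int):
    """Split [num_tokens, hidden * (1 + k)] into the main embedding
    [num_tokens, hidden] and the packed deepstack chunks [num_tokens, hidden * k]."""
    main = embedding[..., :hidden_size]
    deepstack = embedding[..., hidden_size:]
    return main, deepstack
```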

@@ -591,13 +598,37 @@
     inputs_embeds = input_embedding(input_ids)

     # 4. scatter embeddings into input embedding
-    for embedding, mask in zip(embeddings, masks):
+
+    # deepstack: pre-allocate one hidden-size chunk per deepstack layer for every token
+    if use_deepstack:
+        num_deepstack_embeddings = len(multimodal_model.deepstack_visual_indexes)
+        deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
+            inputs_embeds.shape[-1] * num_deepstack_embeddings,
+        )
+        input_deepstack_embeds = torch.zeros(
+            deepstack_embedding_shape,
+            device=inputs_embeds.device,
+            dtype=inputs_embeds.dtype,
+        )
+        other_info["input_deepstack_embeds"] = input_deepstack_embeds
+
+    for i, (embedding, mask) in enumerate(zip(embeddings, masks)):
         if embedding is None or mask is None:
             continue
         # in-place update
         indices = torch.where(mask.squeeze(dim=-1))[0]
         inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
-    return inputs_embeds
+
+        if use_deepstack:
+            input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
+                inputs_embeds.device, inputs_embeds.dtype
+            )
+
+    return inputs_embeds, other_info
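The packed `input_deepstack_embeds` buffer stores `num_deepstack_embeddings` hidden-size chunks per token. The expectation downstream is that the language model slices out chunk `i` and residual-adds it to the hidden states after the decoder layer paired with deepstack index `i`. A sketch of that slicing (the exact hook point is model-specific and not shown in this diff):

```python
import torch


def add_deepstack_chunk(
    hidden_states: torch.Tensor,           # [num_tokens, hidden]
    input_deepstack_embeds: torch.Tensor,  # [num_tokens, hidden * k]
    chunk_idx: int,
) -> torch.Tensor:
    hidden = hidden_states.shape[-1]
    start = chunk_idx * hidden
    # Residual-add the chunk produced by the matching vision deepstack tap.
    return hidden_states + input_deepstack_embeds[:, start : start + hidden]
```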


def general_mm_embed_routine(
@@ -609,6 +640,7 @@ def general_mm_embed_routine(
         Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
     placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
+    use_deepstack: bool = False,
     **kwargs,
 ) -> torch.Tensor:
     """
@@ -620,6 +652,7 @@
         language_model: Base language model to use
         data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
         placeholder_tokens: Token IDs for multimodal placeholders
+        use_deepstack: Whether to compute deepstack embeddings and forward them to the language model
         **kwargs: Additional arguments passed to language model

     Returns:
@@ -645,16 +678,20 @@
         for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
         if forward_batch.mm_inputs[i] is not None
     ]
-    inputs_embeds = embed_mm_inputs(
+    inputs_embeds, other_info = embed_mm_inputs(
         mm_inputs_list=mm_inputs_list,
         extend_prefix_lens=extend_prefix_lens,
         extend_seq_lens=extend_seq_lens,
         input_ids=input_ids,
-        input_embedding=embed_tokens,
         multimodal_model=multimodal_model,
+        input_embedding=embed_tokens,
         data_embedding_func_mapping=data_embedding_funcs,
         placeholder_tokens=placeholder_tokens,
+        use_deepstack=use_deepstack,
     )
+    # qwen3_vl deepstack: forward the packed deepstack embeddings to the language model
+    if use_deepstack:
+        kwargs["input_deepstack_embeds"] = other_info["input_deepstack_embeds"]
     # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
     # just being defensive here
     forward_batch.mm_inputs = None
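Putting the routine together from the model side: a Qwen3-VL implementation opts in with `use_deepstack=True`, and the routine then forwards `input_deepstack_embeds` to the language model through `**kwargs`. A hedged sketch of such a call site (class shape, attribute names, and the `Modality` import path are assumptions):

```python
import torch

from sglang.srt.managers.mm_utils import general_mm_embed_routine
from sglang.srt.managers.schedule_batch import Modality  # assumed import path


class Qwen3VLForConditionalGeneration(torch.nn.Module):  # heavily abridged
    def forward(self, input_ids, positions, forward_batch):
        # use_deepstack=True makes the routine pack deepstack chunks and
        # inject them as kwargs["input_deepstack_embeds"] for self.model.
        return general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.model,
            data_embedding_funcs={Modality.IMAGE: self.get_image_feature},
            use_deepstack=True,
            positions=positions,
        )
```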