Updating Branch #26

Merged
Changes from 1 commit
116 commits
a091e2d
[Kernel] Enable 8-bit weights in Fused Marlin MoE (#8032)
ElizaWszola Sep 16, 2024
837c196
[Frontend] Expose revision arg in OpenAI server (#8501)
lewtun Sep 16, 2024
acd5511
[BugFix] Fix clean shutdown issues (#8492)
njhill Sep 16, 2024
781e3b9
[Bugfix][Kernel] Fix build for sm_60 in GGUF kernel (#8506)
sasha0552 Sep 16, 2024
5d73ae4
[Kernel] AQ AZP 3/4: Asymmetric quantization kernels (#7270)
ProExpertProg Sep 16, 2024
2759a43
[doc] update doc on testing and debugging (#8514)
youkaichao Sep 16, 2024
47f5e03
[Bugfix] Bind api server port before starting engine (#8491)
kevin314 Sep 16, 2024
5478c4b
[perf bench] set timeout to debug hanging (#8516)
simon-mo Sep 16, 2024
5ce45eb
[misc] small qol fixes for release process (#8517)
simon-mo Sep 16, 2024
cca6164
[Bugfix] Fix 3.12 builds on main (#8510)
joerunde Sep 17, 2024
546034b
[refactor] remove triton based sampler (#8524)
simon-mo Sep 17, 2024
1c1bb38
[Frontend] Improve Nullable kv Arg Parsing (#8525)
alex-jw-brooks Sep 17, 2024
ee2bcea
[Misc][Bugfix] Disable guided decoding for mistral tokenizer (#8521)
ywang96 Sep 17, 2024
99aa4ed
[torch.compile] register allreduce operations as custom ops (#8526)
youkaichao Sep 17, 2024
cbdb252
[Misc] Limit to ray[adag] 2.35 to avoid backward incompatible change …
ruisearch42 Sep 17, 2024
1b6de83
[Benchmark] Support sample from HF datasets and image input for bench…
Isotr0py Sep 17, 2024
1009e93
[Encoder decoder] Add cuda graph support during decoding for encoder-…
sroy745 Sep 17, 2024
9855b99
[Feature][kernel] tensor parallelism with bitsandbytes quantization (…
chenqianfzh Sep 17, 2024
a54ed80
[Model] Add mistral function calling format to all models loaded with…
patrickvonplaten Sep 17, 2024
56c3de0
[Misc] Don't dump contents of kvcache tensors on errors (#8527)
njhill Sep 17, 2024
98f9713
[Bugfix] Fix TP > 1 for new granite (#8544)
joerunde Sep 17, 2024
fa0c114
[doc] improve installation doc (#8550)
youkaichao Sep 17, 2024
09deb47
[CI/Build] Excluding kernels/test_gguf.py from ROCm (#8520)
alexeykondrat Sep 17, 2024
8110e44
[Kernel] Change interface to Mamba causal_conv1d_update for continuou…
tlrmchlsmth Sep 17, 2024
95965d3
[CI/Build] fix Dockerfile.cpu on podman (#8540)
dtrifiro Sep 18, 2024
e351572
[Misc] Add argument to disable FastAPI docs (#8554)
Jeffwan Sep 18, 2024
6ffa3f3
[CI/Build] Avoid CUDA initialization (#8534)
DarkLight1337 Sep 18, 2024
9d104b5
[CI/Build] Update Ruff version (#8469)
aarnphm Sep 18, 2024
7c7714d
[Core][Bugfix][Perf] Introduce `MQLLMEngine` to avoid `asyncio` OH (#…
alexm-redhat Sep 18, 2024
a8c1d16
[Core] *Prompt* logprobs support in Multi-step (#8199)
afeldman-nm Sep 18, 2024
d65798f
[Core] zmq: bind only to 127.0.0.1 for local-only usage (#8543)
russellb Sep 18, 2024
e18749f
[Model] Support Solar Model (#8386)
shing100 Sep 18, 2024
b3195bc
[AMD][ROCm]Quantization methods on ROCm; Fix _scaled_mm call (#8380)
gshtras Sep 18, 2024
db9120c
[Kernel] Change interface to Mamba selective_state_update for continu…
tlrmchlsmth Sep 18, 2024
d9cd78e
[BugFix] Nonzero exit code if MQLLMEngine startup fails (#8572)
njhill Sep 18, 2024
0d47bf3
[Bugfix] add `dead_error` property to engine client (#8574)
joerunde Sep 18, 2024
4c34ce8
[Kernel] Remove marlin moe templating on thread_m_blocks (#8573)
tlrmchlsmth Sep 19, 2024
3118f63
[Bugfix] [Encoder-Decoder] Bugfix for encoder specific metadata const…
sroy745 Sep 19, 2024
02c9afa
Revert "[Misc][Bugfix] Disable guided decoding for mistral tokenizer"…
ywang96 Sep 19, 2024
c52ec5f
[Bugfix] fixing sonnet benchmark bug in benchmark_serving.py (#8616)
KuntaiDu Sep 19, 2024
855c8ae
[MISC] remove engine_use_ray in benchmark_throughput.py (#8615)
jikunshang Sep 19, 2024
76515f3
[Frontend] Use MQLLMEngine for embeddings models too (#8584)
njhill Sep 19, 2024
9cc373f
[Kernel][Amd] Add fp8 kv cache support for rocm custom paged attentio…
charlifu Sep 19, 2024
e42c634
[Core] simplify logits resort in _apply_top_k_top_p (#8619)
hidva Sep 19, 2024
ea4647b
[Doc] Add documentation for GGUF quantization (#8618)
Isotr0py Sep 19, 2024
9e99407
Create SECURITY.md (#8642)
simon-mo Sep 19, 2024
6cb748e
[CI/Build] Re-enabling Entrypoints tests on ROCm, excluding ones that…
alexeykondrat Sep 19, 2024
de6f90a
[Misc] guard against change in cuda library name (#8609)
bnellnm Sep 19, 2024
18ae428
[Bugfix] Fix Phi3.5 mini and MoE LoRA inference (#8571)
garg-amit Sep 20, 2024
9e5ec35
[bugfix] [AMD] add multi-step advance_step to ROCmFlashAttentionMetad…
SolitaryThinker Sep 20, 2024
260d40b
[Core] Support Lora lineage and base model metadata management (#6315)
Jeffwan Sep 20, 2024
3b63de9
[Model] Add OLMoE (#7922)
Muennighoff Sep 20, 2024
2940afa
[CI/Build] Removing entrypoints/openai/test_embedding.py test from RO…
alexeykondrat Sep 20, 2024
b28298f
[Bugfix] Validate SamplingParam n is an int (#8548)
saumya-saran Sep 20, 2024
035fa89
[Misc] Show AMD GPU topology in `collect_env.py` (#8649)
DarkLight1337 Sep 20, 2024
2874bac
[Bugfix] Config got an unexpected keyword argument 'engine' (#8556)
Juelianqvq Sep 20, 2024
b4e4eda
[Bugfix][Core] Fix tekken edge case for mistral tokenizer (#8640)
patrickvonplaten Sep 20, 2024
7c8566a
[Doc] neuron documentation update (#8671)
omrishiv Sep 20, 2024
7f9c890
[Hardware][AWS] update neuron to 2.20 (#8676)
omrishiv Sep 20, 2024
0f961b3
[Bugfix] Fix incorrect llava next feature size calculation (#8496)
zyddnys Sep 20, 2024
0057894
[Core] Rename `PromptInputs` and `inputs` (#8673)
DarkLight1337 Sep 21, 2024
d4bf085
[MISC] add support custom_op check (#8557)
jikunshang Sep 21, 2024
0455c46
[Core] Factor out common code in `SequenceData` and `Sequence` (#8675)
DarkLight1337 Sep 21, 2024
0faab90
[beam search] add output for manually checking the correctness (#8684)
youkaichao Sep 21, 2024
71c6049
[Kernel] Build flash-attn from source (#8245)
ProExpertProg Sep 21, 2024
5e85f4f
[VLM] Use `SequenceData.from_token_counts` to create dummy data (#8687)
DarkLight1337 Sep 21, 2024
4dfdf43
[Doc] Fix typo in AMD installation guide (#8689)
Imss27 Sep 21, 2024
ec4aaad
[Kernel][Triton][AMD] Remove tl.atomic_add from awq_gemm_kernel, 2-5x…
rasmith Sep 21, 2024
9dc7c6c
[dbrx] refactor dbrx experts to extend FusedMoe class (#8518)
divakar-amd Sep 21, 2024
d66ac62
[Kernel][Bugfix] Delete some more useless code in marlin_moe_ops.cu (…
tlrmchlsmth Sep 21, 2024
13d88d4
[Bugfix] Refactor composite weight loading logic (#8656)
Isotr0py Sep 22, 2024
0e40ac9
[ci][build] fix vllm-flash-attn (#8699)
youkaichao Sep 22, 2024
06ed281
[Model] Refactor BLIP/BLIP-2 to support composite model loading (#8407)
DarkLight1337 Sep 22, 2024
8ca5051
[Misc] Use NamedTuple in Multi-image example (#8705)
alex-jw-brooks Sep 22, 2024
ca2b628
[MISC] rename CudaMemoryProfiler to DeviceMemoryProfiler (#8703)
ji-huazhong Sep 22, 2024
5b59532
[Model][VLM] Add LLaVA-Onevision model support (#8486)
litianjian Sep 22, 2024
c6bd70d
[SpecDec][Misc] Cleanup, remove bonus token logic. (#8701)
LiuXiaoxuanPKU Sep 22, 2024
d4a2ac8
[build] enable existing pytorch (for GH200, aarch64, nightly) (#8713)
youkaichao Sep 22, 2024
92ba7e7
[misc] upgrade mistral-common (#8715)
youkaichao Sep 22, 2024
3dda7c2
[Bugfix] Avoid some bogus messages RE CUTLASS's revision when buildin…
tlrmchlsmth Sep 23, 2024
57a0702
[Bugfix] Fix CPU CMake build (#8723)
ProExpertProg Sep 23, 2024
d23679e
[Bugfix] fix docker build for xpu (#8652)
yma11 Sep 23, 2024
9b8c8ba
[Core][Frontend] Support Passing Multimodal Processor Kwargs (#8657)
alex-jw-brooks Sep 23, 2024
e551ca1
[Hardware][CPU] Refactor CPU model runner (#8729)
Isotr0py Sep 23, 2024
3e83c12
[Bugfix][CPU] fix missing input intermediate_tensors in the cpu_model…
bigPYJ1151 Sep 23, 2024
a79e522
[Model] Support pp for qwen2-vl (#8696)
liuyanyi Sep 23, 2024
f2bd246
[VLM] Fix paligemma, fuyu and persimmon with transformers 4.45 : use …
janimo Sep 23, 2024
ee5f34b
[CI/Build] use setuptools-scm to set __version__ (#4738)
dtrifiro Sep 23, 2024
86e9c8d
[Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16 and GP…
LucasWilkinson Sep 23, 2024
9b0e3ec
[Kernel][LoRA] Add assertion for punica sgmv kernels (#7585)
jeejeelee Sep 23, 2024
b05f5c9
[Core] Allow IPv6 in VLLM_HOST_IP with zmq (#8575)
russellb Sep 23, 2024
5f7bb58
Fix typical acceptance sampler with correct recovered token ids (#8562)
jiqing-feng Sep 23, 2024
1a2aef3
Add output streaming support to multi-step + async while ensuring Req…
alexm-redhat Sep 23, 2024
530821d
[Hardware][AMD] ROCm6.2 upgrade (#8674)
hongxiayang Sep 24, 2024
88577ac
Fix tests in test_scheduler.py that fail with BlockManager V2 (#8728)
sroy745 Sep 24, 2024
0250dd6
re-implement beam search on top of vllm core (#8726)
youkaichao Sep 24, 2024
3185fb0
Revert "[Core] Rename `PromptInputs` to `PromptType`, and `inputs` to…
simon-mo Sep 24, 2024
b8747e8
[MISC] Skip dumping inputs when unpicklable (#8744)
comaniac Sep 24, 2024
3f06bae
[Core][Model] Support loading weights by ID within models (#7931)
petersalas Sep 24, 2024
8ff7ced
[Model] Expose Phi3v num_crops as a mm_processor_kwarg (#8658)
alex-jw-brooks Sep 24, 2024
cc4325b
[Bugfix] Fix potentially unsafe custom allreduce synchronization (#8558)
hanzhi713 Sep 24, 2024
a928ded
[Kernel] Split Marlin MoE kernels into multiple files (#8661)
ElizaWszola Sep 24, 2024
2529d09
[Frontend] Batch inference for llm.chat() API (#8648)
aandyw Sep 24, 2024
72fc97a
[Bugfix] Fix torch dynamo fixes caused by `replace_parameters` (#8748)
LucasWilkinson Sep 24, 2024
2467b64
[CI/Build] fix setuptools-scm usage (#8771)
dtrifiro Sep 24, 2024
1e7d5c0
[misc] soft drop beam search (#8763)
youkaichao Sep 24, 2024
13f9f7a
[Misc] Upgrade bitsandbytes to the latest version 0.44.0 (#8768)
jeejeelee Sep 25, 2024
01b6f9e
[Core][Bugfix] Support prompt_logprobs returned with speculative deco…
tjohnson31415 Sep 25, 2024
6da1ab6
[Core] Adding Priority Scheduling (#5958)
apatke Sep 25, 2024
6e0c9d6
[Bugfix] Use heartbeats instead of health checks (#8583)
joerunde Sep 25, 2024
ee777d9
Fix test_schedule_swapped_simple in test_scheduler.py (#8780)
sroy745 Sep 25, 2024
b452247
[Bugfix][Kernel] Implement acquire/release polyfill for Pascal (#8776)
sasha0552 Sep 25, 2024
fc3afc2
Fix tests in test_chunked_prefill_scheduler which fail with BlockMana…
sroy745 Sep 25, 2024
e3dd069
[BugFix] Propagate 'trust_remote_code' setting in internvl and minicp…
zifeitong Sep 25, 2024
c239536
[Hardware][CPU] Enable mrope and support Qwen2-VL on CPU backend (#8770)
Isotr0py Sep 25, 2024
3e073e6
[Bugfix] load fc bias from config for eagle (#8790)
sohamparikh Sep 25, 2024
[BugFix] Propagate 'trust_remote_code' setting in internvl and minicp…
zifeitong authored Sep 25, 2024
commit e3dd0692fa2c803cd6f59a88d2fdf8bca26d8d96
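
This commit replaces three models' hardcoded trust_remote_code=True tokenizer loads with the value carried on the model config, so the setting chosen at the entry point is honored. A minimal sketch of the before/after pattern (the config class and loader below are stand-ins, not vLLM's actual types):

    from dataclasses import dataclass

    @dataclass
    class StubModelConfig:
        tokenizer: str
        trust_remote_code: bool = False

    def load_tokenizer(name: str, trust_remote_code: bool = False) -> str:
        # Stand-in for cached_get_tokenizer / AutoTokenizer.from_pretrained.
        return f"{name} (trust_remote_code={trust_remote_code})"

    config = StubModelConfig(tokenizer="some/model", trust_remote_code=False)

    # Before: the user's choice was silently overridden.
    load_tokenizer(config.tokenizer, trust_remote_code=True)

    # After: the configured value is propagated.
    load_tokenizer(config.tokenizer,
                   trust_remote_code=config.trust_remote_code)
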
15 changes: 9 additions & 6 deletions vllm/model_executor/models/internvl.py
@@ -230,8 +230,9 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")
 
-    tokenizer = cached_get_tokenizer(model_config.tokenizer,
-                                     trust_remote_code=True)
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)
 
     prompt = llm_inputs.get("prompt")
     prompt_token_ids = llm_inputs["prompt_token_ids"]
@@ -278,8 +279,9 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
             use_thumbnail=use_thumbnail) for img in data
     ]
     model_config = ctx.model_config
-    tokenizer = cached_get_tokenizer(model_config.tokenizer,
-                                     trust_remote_code=True)
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)
     image_token_id = tokenizer.encode(IMG_CONTEXT,
                                       add_special_tokens=False,
                                       return_tensors="pt")[0]
@@ -298,8 +300,9 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
     model_config = ctx.model_config
     hf_config = ctx.get_hf_config()
     vision_config = hf_config.vision_config
-    tokenizer = cached_get_tokenizer(model_config.tokenizer,
-                                     trust_remote_code=True)
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)
 
     seq_data = dummy_seq_data_for_clip(
         vision_config,
137 changes: 108 additions & 29 deletions vllm/model_executor/models/minicpmv.py
@@ -33,6 +33,7 @@
 from torch import nn
 from torch.nn.init import trunc_normal_
 from transformers import PretrainedConfig
+from typing_extensions import NotRequired
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
@@ -52,6 +53,7 @@
 from vllm.model_executor.models.qwen2 import Qwen2Model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.image import cached_get_image_processor
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData
@@ -64,6 +66,17 @@
 }
 
 
+class MiniCPMVImageInput(TypedDict):
+    """Input mapper input with auxiliary data for computing image bounds."""
+    image: Image.Image
+
+    # Image bounds token ids in 0-dim scaler tensor.
+    im_start_id: torch.Tensor
+    im_end_id: torch.Tensor
+    slice_start_id: NotRequired[torch.Tensor]
+    slice_end_id: NotRequired[torch.Tensor]
+
+
 class MiniCPMVImagePixelInputs(TypedDict):
     pixel_values: List[torch.Tensor]
     """
@@ -88,8 +101,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
     """
 
 
-MiniCPMVImageInputs = MiniCPMVImagePixelInputs
-
 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
 
 
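
The new MiniCPMVImageInput uses typing_extensions.NotRequired so the slice-token ids can be absent for tokenizers that have no slice tokens. A small self-contained sketch of how such optional keys behave (toy id values, not the real MiniCPM-V ids):

    import torch
    from PIL import Image
    from typing_extensions import NotRequired, TypedDict

    class ImageInput(TypedDict):
        image: Image.Image
        im_start_id: torch.Tensor
        im_end_id: torch.Tensor
        slice_start_id: NotRequired[torch.Tensor]
        slice_end_id: NotRequired[torch.Tensor]

    img = Image.new("RGB", (8, 8))

    # Tokenizer without slice tokens: the NotRequired keys may be omitted.
    plain = ImageInput(image=img,
                       im_start_id=torch.tensor(101),
                       im_end_id=torch.tensor(102))

    # Tokenizer with slice tokens: all keys present.
    sliced = ImageInput(image=img,
                        im_start_id=torch.tensor(101),
                        im_end_id=torch.tensor(102),
                        slice_start_id=torch.tensor(103),
                        slice_end_id=torch.tensor(104))

    assert "slice_start_id" not in plain
    assert "slice_start_id" in sliced  # downstream code checks membership
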
@@ -234,6 +245,25 @@ def forward(self, x: torch.Tensor,
         return x
 
 
+def _build_image_input(ctx: InputContext,
+                       image: Image.Image) -> MiniCPMVImageInput:
+    tokenizer = cached_get_tokenizer(
+        ctx.model_config.tokenizer,
+        trust_remote_code=ctx.model_config.trust_remote_code)
+    if hasattr(tokenizer, "slice_start_id"):
+        return MiniCPMVImageInput(
+            image=image,
+            im_start_id=torch.tensor(tokenizer.im_start_id),
+            im_end_id=torch.tensor(tokenizer.im_end_id),
+            slice_start_id=torch.tensor(tokenizer.slice_start_id),
+            slice_end_id=torch.tensor(tokenizer.slice_end_id))
+    else:
+        return MiniCPMVImageInput(image=image,
+                                  im_start_id=torch.tensor(
+                                      tokenizer.im_start_id),
+                                  im_end_id=torch.tensor(tokenizer.im_end_id))
+
+
 def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
     version_float = getattr(config, "version", None)
 
@@ -257,10 +287,13 @@ def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
     return SequenceData.from_token_counts((0, seq_len))
 
 
-def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int):
+def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig,
+                             num_images: int):
     width = height = hf_config.image_size
-    image = Image.new("RGB", (width, height), color=0)
-    return {"image": image if num_images == 1 else [image] * num_images}
+    image = _build_image_input(ctx,
+                               image=Image.new("RGB", (width, height),
+                                               color=0))
+    return {"image": [image] if num_images == 1 else [image] * num_images}
 
 
def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
Expand All @@ -269,7 +302,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
num_images = mm_counts["image"]

seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
mm_data = dummy_image_for_minicpmv(hf_config, num_images)
mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images)

return seq_data, mm_data

@@ -280,8 +313,9 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
         return llm_inputs
     model_config = ctx.model_config
     version = get_version_by_config(model_config.hf_config)
-    tokenizer = cached_get_tokenizer(model_config.tokenizer,
-                                     trust_remote_code=True)
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)
     image_processor = cached_get_image_processor(model_config.tokenizer)
 
     def get_placeholder(image_size: Tuple[int, int], num_image: int):
@@ -317,6 +351,10 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int):
     new_prompt = "".join(new_prompt_chunks)
     new_token_ids = tokenizer.encode(new_prompt)
 
+    multi_modal_data["image"] = [
+        _build_image_input(ctx, image) for image in images
+    ]
+
     llm_inputs = LLMInputs(
         prompt_token_ids=new_token_ids,
         prompt=new_prompt,
@@ -325,6 +363,32 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int):
     return llm_inputs
 
 
+def input_mapper_for_minicpmv(ctx: InputContext, data: object):
+    model_config = ctx.model_config
+
+    image_processor = cached_get_image_processor(
+        model_config.model, trust_remote_code=model_config.trust_remote_code)
+    if image_processor is None:
+        raise RuntimeError("No HuggingFace processor is available "
+                           "to process the image object")
+
+    if not isinstance(data, list):
+        raise ValueError(
+            "Image input must be list of MiniCPMVImageInput, got (%s)", data)
+    batch_data = image_processor \
+        .preprocess([img["image"] for img in data], return_tensors="pt") \
+        .data
+
+    if len(data) > 0:
+        batch_data["im_start_id"] = data[0]["im_start_id"]
+        batch_data["im_end_id"] = data[0]["im_end_id"]
+        if "slice_start_id" in data[0]:
+            batch_data["slice_start_id"] = data[0]["slice_start_id"]
+            batch_data["slice_end_id"] = data[0]["slice_end_id"]
+
+    return MultiModalInputs(batch_data)
+
+
 class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
     """
     The abstract class of MiniCPMV can only be inherited, but cannot be
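
Since every image in a request carries the same special bound-token ids, the mapper keeps only the first copy of those ids alongside the batched pixel data. A toy reduction of that logic, with stub strings in place of real image-processor output:

    import torch

    data = [{"image": f"img{i}",
             "im_start_id": torch.tensor(128),
             "im_end_id": torch.tensor(129)} for i in range(3)]

    # Stands in for image_processor.preprocess(...).data above.
    batch_data = {"images": [d["image"] for d in data]}

    if len(data) > 0:
        batch_data["im_start_id"] = data[0]["im_start_id"]
        batch_data["im_end_id"] = data[0]["im_end_id"]
        if "slice_start_id" in data[0]:
            batch_data["slice_start_id"] = data[0]["slice_start_id"]
            batch_data["slice_end_id"] = data[0]["slice_end_id"]

    print(sorted(batch_data))  # ['im_end_id', 'im_start_id', 'images']
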
@@ -365,7 +429,7 @@ def __init__(
     def get_embedding(
         self,
         input_ids: torch.Tensor,
-        image_inputs: Optional[MiniCPMVImageInputs],
+        image_inputs: Optional[MiniCPMVImagePixelInputs],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
         if hasattr(self.config, "scale_emb"):
@@ -393,14 +457,20 @@ def get_embedding(
 
         return vlm_embedding, vision_hidden_states
 
-    def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor:
-        tokenizer = cached_get_tokenizer(self.config._name_or_path,
-                                         trust_remote_code=True)
-        start_cond = input_ids == tokenizer.im_start_id
-        end_cond = input_ids == tokenizer.im_end_id
-        if hasattr(tokenizer, "slice_start_id"):
-            start_cond |= (input_ids == tokenizer.slice_start_id)
-            end_cond |= (input_ids == tokenizer.slice_end_id)
+    def _get_image_bounds(
+            self,
+            input_ids: torch.Tensor,
+            im_start_id: torch.Tensor,
+            im_end_id: torch.Tensor,
+            slice_start_id: Optional[torch.Tensor] = None,
+            slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+        # All the images in the batch should share the same special image
+        # bound token ids.
+        start_cond = input_ids == im_start_id[0]
+        end_cond = input_ids == im_end_id[0]
+        if slice_start_id is not None:
+            start_cond |= (input_ids == slice_start_id[0])
+            end_cond |= (input_ids == slice_end_id[0])
 
         image_start_tokens, = torch.where(start_cond)
         image_start_tokens += 1
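
The rewritten _get_image_bounds no longer reloads a tokenizer inside the model; it works purely from the ids passed through the multimodal kwargs. A toy reproduction of the visible part of the computation (ids 5 and 6 are made up, and the final stacking of start/end indices is assumed from context since the diff folds the rest of the method):

    import torch

    input_ids = torch.tensor([1, 5, 10, 11, 6, 2, 5, 12, 6, 2])
    im_start_id = torch.tensor([5])
    im_end_id = torch.tensor([6])

    start_cond = input_ids == im_start_id[0]
    end_cond = input_ids == im_end_id[0]

    image_start_tokens, = torch.where(start_cond)
    image_start_tokens += 1  # each bound begins just after its start token
    image_end_tokens, = torch.where(end_cond)

    bounds = torch.hstack([image_start_tokens.unsqueeze(-1),
                           image_end_tokens.unsqueeze(-1)])
    print(bounds)  # tensor([[2, 4], [7, 8]])
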
@@ -419,7 +489,7 @@ def _parse_and_validate_inputs(
         self,
         input_ids: torch.Tensor,
         **kwargs: object,
-    ) -> Optional[MiniCPMVImageInputs]:
+    ) -> Optional[MiniCPMVImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", [])
         tgt_sizes = kwargs.pop("tgt_sizes", [])
 
@@ -456,8 +526,17 @@ def _parse_and_validate_inputs(
         if len(pixel_values_flat) == 0:
             return None
 
-        return MiniCPMVImageInputs(
-            image_bounds=self._get_image_bounds(input_ids),
+        im_start_id = kwargs.pop("im_start_id", None)
+        im_end_id = kwargs.pop("im_end_id", None)
+        slice_start_id = kwargs.pop("slice_start_id", None)
+        slice_end_id = kwargs.pop("slice_end_id", None)
+        if im_start_id is None:
+            return None
+
+        return MiniCPMVImagePixelInputs(
+            image_bounds=self._get_image_bounds(input_ids, im_start_id,
+                                                im_end_id, slice_start_id,
+                                                slice_end_id),
             pixel_values=pixel_values_flat,
             tgt_sizes=torch.stack(tgt_sizes_flat),
         )
@@ -564,8 +643,8 @@ def get_vision_embedding(
     ) -> torch.Tensor:
         raise NotImplementedError
 
-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         raise NotImplementedError
 
     def is_default_weight_loading(self, name: str) -> bool:
@@ -654,8 +733,8 @@ def get_vision_embedding(
             res.append(self.resampler(vision_embedding, tgt_size))
         return torch.vstack(res)
 
-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         pixel_values = data["pixel_values"]
 
         return self.get_vision_embedding(pixel_values)
@@ -713,8 +792,8 @@ def get_vision_embedding(
         vision_embedding = self.resampler(vision_embedding, tgt_sizes)
         return vision_embedding
 
-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         pixel_values = data["pixel_values"]
         tgt_sizes = data["tgt_sizes"]
 
@@ -807,8 +886,8 @@ def get_vision_embedding(
         ).last_hidden_state
         return vision_embedding
 
-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         pixel_values = data["pixel_values"]
         tgt_sizes = data["tgt_sizes"]
 
@@ -851,7 +930,7 @@ def is_default_weight_loading(self, name: str) -> bool:
 }
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper()
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_minicpmv)
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
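
Passing input_mapper_for_minicpmv to the decorator replaces the default mapper that the bare register_image_input_mapper() would install. A minimal sketch of that registry pattern, not vLLM's actual registry code:

    from typing import Callable, Dict, Optional, Type

    class ToyRegistry:
        def __init__(self):
            self.mappers: Dict[Type, Callable] = {}

        def register_image_input_mapper(self,
                                        mapper: Optional[Callable] = None):
            def wrapper(model_cls: Type) -> Type:
                # None selects the default mapper; a function installs a
                # model-specific one.
                self.mappers[model_cls] = mapper or self._default_mapper
                return model_cls
            return wrapper

        @staticmethod
        def _default_mapper(ctx, data):
            return {"image": data}

    REGISTRY = ToyRegistry()

    def custom_mapper(ctx, data):
        return {"custom": data}

    @REGISTRY.register_image_input_mapper(custom_mapper)
    class ToyModel:
        pass

    assert REGISTRY.mappers[ToyModel] is custom_mapper
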
15 changes: 9 additions & 6 deletions vllm/model_executor/models/qwen.py
@@ -674,8 +674,9 @@ def input_processor_for_qwen(ctx: InputContext,
     prompt = llm_inputs.get("prompt")
     prompt_token_ids = llm_inputs["prompt_token_ids"]
     model_config = ctx.model_config
-    tokenizer = cached_get_tokenizer(model_config.tokenizer,
-                                     trust_remote_code=True)
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)
     image_data = multi_modal_data["image"]
     if isinstance(image_data, torch.Tensor):
         num_dims = len(image_data.shape)
@@ -735,8 +736,9 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
         return MultiModalInputs()
 
     model_config = ctx.model_config
-    tokenizer = cached_get_tokenizer(model_config.tokenizer,
-                                     trust_remote_code=True)
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)
 
     image_pair_tok = tokenizer.encode(IMG_START + IMG_END,
                                       add_special_tokens=False,
@@ -824,8 +826,9 @@ def dummy_data_for_qwen(
     # We have a visual component - use images to warm up
     num_images = mm_counts["image"]
     model_config = ctx.model_config
-    tokenizer = cached_get_tokenizer(model_config.tokenizer,
-                                     trust_remote_code=True)
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)
 
     # Build the image prompts with no imgpads; the tokenizer will add img pads
     image_prompt = ''.join(
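
With all three model files updated, the trust_remote_code decision made at the entry point now reaches every tokenizer load inside these multimodal input pipelines; previously the internvl, minicpmv, and qwen paths forced trust_remote_code=True regardless. A hedged usage sketch (the model name is a placeholder for any checkpoint whose tokenizer lives in remote code):

    from vllm import LLM

    # The flag set here is now the value the input processors above see.
    llm = LLM(model="org/remote-code-model", trust_remote_code=True)
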