Commit 7025b11

[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)

1 parent: 5469146

59 files changed: 414 additions, 205 deletions

Note: large commits hide some of their content by default, so only a subset of the changed files is shown below.

tests/conftest.py

Lines changed: 18 additions & 8 deletions

@@ -4,7 +4,8 @@
 import sys
 from collections import UserList
 from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
+from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
+                    TypeVar, Union)
 
 import pytest
 import torch
@@ -27,7 +28,7 @@
 from vllm.outputs import RequestOutput
 from vllm.sequence import SampleLogprobs
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
-                        is_cpu)
+                        identity, is_cpu)
 
 logger = init_logger(__name__)
 
@@ -197,6 +198,8 @@ def __init__(
         is_embedding_model: bool = False,
         is_vision_model: bool = False,
         is_encoder_decoder_model: bool = False,
+        postprocess_inputs: Callable[[BatchEncoding],
+                                     BatchEncoding] = identity,
     ) -> None:
         torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
 
@@ -242,12 +245,14 @@ def __init__(
                 torch_dtype=torch_dtype,
                 trust_remote_code=True,
             )
-        except Exception:
+        except Exception as exc:
             logger.warning(
-                "Unable to auto-load processor from HuggingFace for "
-                "model %s. Using tokenizer instead.", model_name)
+                "Unable to auto-load HuggingFace processor for model (%s). "
+                "Using tokenizer instead. Reason: %s", model_name, exc)
             self.processor = self.tokenizer
 
+        self.postprocess_inputs = postprocess_inputs
+
     def generate(
         self,
         prompts: List[str],
@@ -267,6 +272,7 @@ def generate(
                 processor_kwargs["images"] = images[i]
 
             inputs = self.processor(**processor_kwargs)
+            inputs = self.postprocess_inputs(inputs)
 
             output_ids = self.model.generate(
                 **self.wrap_device(inputs),
@@ -336,6 +342,7 @@ def generate_greedy_logprobs(
                 processor_kwargs["images"] = images[i]
 
             inputs = self.processor(**processor_kwargs)
+            inputs = self.postprocess_inputs(inputs)
 
             output = self.model.generate(
                 **self.wrap_device(inputs),
@@ -420,6 +427,7 @@ def generate_greedy_logprobs_limit(
                 processor_kwargs["images"] = images[i]
 
             inputs = self.processor(**processor_kwargs)
+            inputs = self.postprocess_inputs(inputs)
 
             output = self.model.generate(
                 **self.wrap_device(inputs),
@@ -552,7 +560,8 @@ def generate(
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[Union[List[Image.Image],
+                               List[List[Image.Image]]]] = None,
     ) -> List[Tuple[List[List[int]], List[str]]]:
         if images is not None:
             assert len(prompts) == len(images)
@@ -587,7 +596,7 @@ def _final_steps_generate_w_logprobs(
         for req_output in req_outputs:
             for sample in req_output.outputs:
                 output_str = sample.text
-                output_ids = sample.token_ids
+                output_ids = list(sample.token_ids)
                 output_logprobs = sample.logprobs
                 outputs.append((output_ids, output_str, output_logprobs))
         return outputs
@@ -596,7 +605,8 @@ def generate_w_logprobs(
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[Union[List[Image.Image],
+                               List[List[Image.Image]]]] = None,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         assert sampling_params.logprobs is not None

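As an aside, here is a minimal, self-contained sketch (not part of the commit) of the hook pattern the conftest change introduces: HfRunner now accepts a postprocess_inputs callable that is applied to the processor output right before generation, defaulting to an identity function. DummyRunner and cast_pixel_values below are illustrative stand-ins rather than vLLM code.

from typing import Callable

import torch
from transformers import BatchEncoding


def identity(value):
    # Default hook: pass the processor output through unchanged
    # (mirrors what vllm.utils.identity is used for above).
    return value


class DummyRunner:
    # Illustrative stand-in for HfRunner's hook plumbing.

    def __init__(
        self,
        postprocess_inputs: Callable[[BatchEncoding],
                                     BatchEncoding] = identity,
    ) -> None:
        self.postprocess_inputs = postprocess_inputs

    def prepare(self, inputs: BatchEncoding) -> BatchEncoding:
        # Applied to the processor output before model.generate().
        return self.postprocess_inputs(inputs)


def cast_pixel_values(hf_inputs: BatchEncoding) -> BatchEncoding:
    # Example hook in the spirit of the Chameleon/LLaVA tests:
    # cast image tensors to the model dtype, leave token ids alone.
    hf_inputs["pixel_values"] = hf_inputs["pixel_values"].to(torch.bfloat16)
    return hf_inputs


runner = DummyRunner(postprocess_inputs=cast_pixel_values)
batch = BatchEncoding({"pixel_values": torch.zeros(1, 3, 8, 8)})
print(runner.prepare(batch)["pixel_values"].dtype)  # torch.bfloat16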
tests/distributed/test_multimodal_broadcast.py

Lines changed: 4 additions & 0 deletions

@@ -18,8 +18,10 @@
 @pytest.mark.parametrize("model, distributed_executor_backend", [
     ("llava-hf/llava-1.5-7b-hf", "ray"),
     ("llava-hf/llava-v1.6-mistral-7b-hf", "ray"),
+    ("facebook/chameleon-7b", "ray"),
     ("llava-hf/llava-1.5-7b-hf", "mp"),
     ("llava-hf/llava-v1.6-mistral-7b-hf", "mp"),
+    ("facebook/chameleon-7b", "mp"),
 ])
 @fork_new_process_for_each_test
 def test_models(hf_runner, vllm_runner, image_assets, model: str,
@@ -34,6 +36,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str,
         from ..models.test_llava import models, run_test
     elif model.startswith("llava-hf/llava-v1.6"):
         from ..models.test_llava_next import models, run_test
+    elif model.startswith("facebook/chameleon"):
+        from ..models.test_chameleon import models, run_test
     else:
         raise NotImplementedError(f"Unsupported model: {model}")
 

tests/entrypoints/openai/test_oot_registration.py

Lines changed: 6 additions & 2 deletions

@@ -1,5 +1,6 @@
 import sys
 import time
+from typing import Optional
 
 import torch
 from openai import OpenAI, OpenAIError
@@ -17,8 +18,11 @@
 
 class MyOPTForCausalLM(OPTForCausalLM):
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         # this dummy model always predicts the first token
         logits = super().compute_logits(hidden_states, sampling_metadata)
         logits.zero_()

tests/models/test_chameleon.py

Lines changed: 58 additions & 33 deletions

@@ -1,11 +1,13 @@
-import re
 from typing import List, Optional, Type
 
 import pytest
+from transformers import BatchEncoding
 
 from vllm.multimodal.utils import rescale_image_size
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
-from ..conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets
+from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from .utils import check_outputs_equal
 
 pytestmark = pytest.mark.vlm
 
@@ -19,23 +21,29 @@
 models = ["facebook/chameleon-7b"]
 
 
-#TODO (ywang96): Add correctness test when chameleon is
-# available on transformers.
 def run_test(
+    hf_runner: Type[HfRunner],
     vllm_runner: Type[VllmRunner],
     image_assets: _ImageAssets,
     model: str,
     *,
     size_factors: List[float],
     dtype: str,
     max_tokens: int,
+    num_logprobs: int,
     tensor_parallel_size: int,
     distributed_executor_backend: Optional[str] = None,
 ):
-    """Test if the model can generate text given
-    a batch of images and prompts.
-
+    """Inference result should be the same between hf and vllm.
+
+    All the image fixtures for the test is under tests/images.
+    For huggingface runner, we provide the PIL images as input.
+    For vllm runner, we provide MultiModalDataDict objects
+    and corresponding vision language config as input.
+    Note, the text input is also adjusted to abide by vllm contract.
+    The text output is sanitized to be able to compare with hf.
     """
+    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
     images = [asset.pil_image for asset in image_assets]
 
     inputs_per_image = [(
@@ -50,35 +58,49 @@ def run_test(
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
 
-        for prompts, images in inputs_per_image:
-            vllm_outputs = vllm_model.generate_greedy(prompts,
-                                                      max_tokens,
-                                                      images=images)
-            for i in range(len(vllm_outputs)):
-
-                # format prompt back to original
-                replacements = {
-                    "<racm3:break>": "",
-                    "<eoss>": "",
-                    "<reserved08706>": ""
-                }
-                pattern = '|'.join(replacements.keys())
-                vllm_result = re.sub(
-                    pattern,
-                    lambda match: replacements[match.group(0)], #noqa B023
-                    vllm_outputs[i][1])
-                vllm_result = vllm_result.replace("<image>", "", 1023)
-                assert vllm_result[:len(prompts[i])] == prompts[i]
-
-                # assert at least 10 new characters are generated
-                # (to take stop token into account)
-                assert len(vllm_outputs[i][1]) - len(prompts[i]) > 10
+        vllm_outputs_per_image = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images)
+            for prompts, images in inputs_per_image
+        ]
+
+    def process(hf_inputs: BatchEncoding):
+        hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
+            .to(torch_dtype)  # type: ignore
+        return hf_inputs
+
+    with hf_runner(model,
+                   dtype=dtype,
+                   postprocess_inputs=process,
+                   is_vision_model=True) as hf_model:
+        hf_outputs_per_image = [
+            hf_model.generate_greedy_logprobs_limit(prompts,
+                                                    max_tokens,
+                                                    num_logprobs=num_logprobs,
+                                                    images=images)
+            for prompts, images in inputs_per_image
+        ]
+
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
+                                        vllm_outputs_per_image):
+        # HF Logprobs include image tokens, unlike vLLM, so we don't directly
+        # compare them
+        check_outputs_equal(
+            outputs_0_lst=[outputs[:2] for outputs in hf_outputs],
+            outputs_1_lst=[outputs[:2] for outputs in vllm_outputs],
+            name_0="hf",
+            name_1="vllm",
+        )
 
 
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "size_factors",
     [
+        # No image
+        [],
         # Single-scale
         [1.0],
         # Single-scale, batched
@@ -88,15 +110,18 @@ def run_test(
     ],
 )
 @pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [128])
-def test_models(vllm_runner, image_assets, model, size_factors, dtype: str,
-                max_tokens: int) -> None:
+@pytest.mark.parametrize("max_tokens", [8])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
+                dtype, max_tokens, num_logprobs) -> None:
     run_test(
+        hf_runner,
        vllm_runner,
        image_assets,
        model,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
     )

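check_outputs_equal comes from tests/models/utils.py, which is not part of the diff shown here. The following simplified, hypothetical stand-in illustrates what the Chameleon test asserts once logprobs are stripped via outputs[:2]: the greedy token ids (and hence the decoded text) must match between the HF and vLLM runners.

from typing import List, Tuple


def check_outputs_equal_sketch(outputs_0_lst: List[Tuple[List[int], str]],
                               outputs_1_lst: List[Tuple[List[int], str]],
                               name_0: str,
                               name_1: str) -> None:
    # Simplified stand-in for tests/models/utils.check_outputs_equal:
    # greedy decoding should yield identical token ids (and hence text).
    assert len(outputs_0_lst) == len(outputs_1_lst)
    for prompt_idx, (outputs_0, outputs_1) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        output_ids_0, output_str_0 = outputs_0
        output_ids_1, output_str_1 = outputs_1
        assert output_ids_0 == output_ids_1, (
            f"{name_0} != {name_1} on prompt {prompt_idx}: "
            f"{output_str_0!r} vs {output_str_1!r}")


# Toy usage with dummy data:
check_outputs_equal_sketch(
    outputs_0_lst=[([101, 2003, 102], "a stop sign")],
    outputs_1_lst=[([101, 2003, 102], "a stop sign")],
    name_0="hf",
    name_1="vllm",
)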
tests/models/test_llava.py

Lines changed: 13 additions & 8 deletions

@@ -1,7 +1,7 @@
 from typing import List, Optional, Tuple, Type
 
 import pytest
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer, BatchEncoding
 
 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
@@ -110,16 +110,21 @@ def run_test(
         for prompts, images in inputs_per_image
     ]
 
-    with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
-        if mantis_processor is not None:
+    if mantis_processor is not None:
 
-            def process(*args, **kwargs):
-                output = mantis_processor(*args, **kwargs)
-                output["pixel_values"] = output["pixel_values"].to(torch_dtype)
-                return output
+        def process(hf_inputs: BatchEncoding):
+            hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
+                .to(torch_dtype)  # type: ignore
+            return hf_inputs
+    else:
 
-            hf_model.processor = process
+        def process(hf_inputs: BatchEncoding):
+            return hf_inputs
 
+    with hf_runner(model,
+                   dtype=dtype,
+                   postprocess_inputs=process,
+                   is_vision_model=True) as hf_model:
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,

tests/models/test_minicpmv.py

Lines changed: 9 additions & 24 deletions

@@ -1,10 +1,9 @@
-from collections import UserDict
 from typing import List, Optional, Tuple, Type
 
 import pytest
 import torch
 import torch.types
-from transformers import BatchFeature
+from transformers import BatchEncoding
 
 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
@@ -14,18 +13,6 @@
 
 pytestmark = pytest.mark.vlm
 
-
-class NestedInputs(UserDict):
-
-    def __init__(self, model_inputs: BatchFeature):
-        super().__init__({"model_inputs": model_inputs})
-
-        self.model_inputs = model_inputs
-
-    def to(self, device: torch.types.Device):
-        return NestedInputs(self.model_inputs.to(device))
-
-
 # The image token is placed before "user" on purpose so that the test can pass
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign":
@@ -41,6 +28,10 @@ def to(self, device: torch.types.Device):
 models = ["openbmb/MiniCPM-Llama3-V-2_5"]
 
 
+def _wrap_inputs(hf_inputs: BatchEncoding) -> BatchEncoding:
+    return BatchEncoding({"model_inputs": hf_inputs})
+
+
 def trunc_hf_output(hf_output: Tuple[List[int], str,
                                      Optional[SampleLogprobs]]):
     output_ids, output_str, out_logprobs = hf_output
@@ -105,11 +96,8 @@ def run_test(
         for prompts, images in inputs_per_image
     ]
 
-    with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
-        hf_processor = hf_model.processor
-        hf_model.processor = lambda **kw: NestedInputs(
-            hf_processor(**kw)  # type: ignore
-        )
+    hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs)
+    with hf_model, torch.no_grad():
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,
@@ -224,11 +212,8 @@ def run_multi_image_test(
         for prompts, images in inputs_per_case
     ]
 
-    with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
-        hf_processor = hf_model.processor
-        hf_model.processor = lambda **kw: NestedInputs(
-            hf_processor(**kw)  # type: ignore
-        )
+    hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs)
+    with hf_model, torch.no_grad():
         hf_outputs_per_case = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,

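As a quick illustration of the simplification above (assuming only standard transformers behaviour), wrapping the processor output in a plain BatchEncoding keyed by "model_inputs" reproduces the nesting that the removed NestedInputs UserDict provided, and the helper plugs straight into the new hook via hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs).

import torch
from transformers import BatchEncoding


def _wrap_inputs(hf_inputs: BatchEncoding) -> BatchEncoding:
    # Same helper as in the test above: nest everything under "model_inputs".
    return BatchEncoding({"model_inputs": hf_inputs})


inner = BatchEncoding({"input_ids": torch.tensor([[1, 2, 3]])})
wrapped = _wrap_inputs(inner)

print(list(wrapped.keys()))              # ['model_inputs']
print(wrapped["model_inputs"] is inner)  # True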