[Bugfix] Fix auto dtype casting for BatchFeature #19316

Merged
merged 9 commits on Jun 14, 2025
25 changes: 16 additions & 9 deletions tests/v1/engine/test_async_llm.py
@@ -15,6 +15,7 @@
from vllm.inputs import PromptType
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import LoggingStatLogger

@@ -107,7 +108,8 @@ async def test_load(
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

engine = AsyncLLM.from_engine_args(engine_args)
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(engine_args)
Comment on lines +111 to +112 (Collaborator, Author):

It seems that disabling OpenMP by setting torch_num_threads=1 while the engine is forked fixes the deadlock issue locally. Let's see how CI behaves.
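
For illustration, a minimal sketch of what that amounts to (names such as `AsyncLLM` and `engine_args` come from the surrounding test; the PR wraps this pattern in the `set_default_torch_num_threads` context manager added to `vllm/utils.py`):

```python
# Sketch only: cap PyTorch's intra-op (OpenMP) thread pool while the engine
# core process is forked, then restore the previous setting.
import torch

old_num_threads = torch.get_num_threads()
torch.set_num_threads(1)  # keep the OpenMP pool quiet during the fork
try:
    engine = AsyncLLM.from_engine_args(engine_args)  # engine core is forked here
finally:
    torch.set_num_threads(old_num_threads)  # restore for the rest of the test
```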

after.callback(engine.shutdown)

NUM_REQUESTS = 100
@@ -154,7 +156,8 @@ async def test_abort(
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

engine = AsyncLLM.from_engine_args(engine_args)
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)

NUM_REQUESTS = 100
@@ -226,7 +229,8 @@ async def test_finished_flag(
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

engine = AsyncLLM.from_engine_args(engine_args)
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)

sampling_params = SamplingParams(
@@ -260,7 +264,8 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

engine = AsyncLLM.from_engine_args(engine_args)
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)

NUM_REQUESTS = 100
@@ -322,10 +327,11 @@ async def test_customize_loggers(monkeypatch):
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

engine = AsyncLLM.from_engine_args(
TEXT_ENGINE_ARGS,
stat_loggers=[MockLoggingStatLogger],
)
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(
TEXT_ENGINE_ARGS,
stat_loggers=[MockLoggingStatLogger],
)
after.callback(engine.shutdown)

await engine.do_log_stats()
@@ -340,7 +346,8 @@ async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)

sampling_params = SamplingParams(max_tokens=100,
22 changes: 13 additions & 9 deletions tests/v1/engine/test_engine_core.py
@@ -12,6 +12,7 @@
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor, UniProcExecutor
@@ -56,9 +57,10 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)

engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
"""Test basic request lifecycle."""

# First request.
@@ -190,9 +192,10 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)

engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
"""Test basic request lifecycle."""
# First request.
request: EngineCoreRequest = make_request()
@@ -286,9 +289,10 @@ def shutdown(self):
enforce_eager=True,
)
vllm_config = engine_args.create_engine_config()
engine_core = EngineCore(vllm_config=vllm_config,
log_stats=False,
executor_class=DummyExecutor)
with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
log_stats=False,
executor_class=DummyExecutor)
assert engine_core.batch_queue is not None

# Add two requests in a row. Each request have 12 prompt tokens.
63 changes: 35 additions & 28 deletions tests/v1/engine/test_engine_core_client.py
@@ -19,6 +19,7 @@
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
@@ -138,13 +139,15 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
vllm_config = engine_args.create_engine_config(
UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)

with set_default_torch_num_threads(1):
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)

MAX_TOKENS = 20
params = SamplingParams(max_tokens=MAX_TOKENS)
@@ -223,13 +226,15 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client(
multiprocess_mode=True,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True,
)

with set_default_torch_num_threads(1):
client = EngineCoreClient.make_client(
multiprocess_mode=True,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True,
)

try:
MAX_TOKENS = 20
@@ -312,13 +317,14 @@ def test_kv_cache_events(
UsageContext.UNKNOWN_CONTEXT)

executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
with set_default_torch_num_threads(1):
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
subscriber = MockSubscriber(endpoint,
topic=publisher_config.topic,
@@ -394,13 +400,14 @@ async def test_kv_cache_events_dp(
UsageContext.UNKNOWN_CONTEXT)

executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
with set_default_torch_num_threads(1):
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
await asyncio.sleep(1)

# Build endpoints for all DP ranks
4 changes: 3 additions & 1 deletion vllm/inputs/registry.py
@@ -168,10 +168,12 @@ def maybe_cast_dtype(x)
try:
output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
# this emulates output.to(dtype=self.model_config.dtype)
cast_output = json_map_leaves(maybe_cast_dtype, output)
if isinstance(output, BatchFeature):
cast_output = json_map_leaves(maybe_cast_dtype, output.data)
return BatchFeature(cast_output)

cast_output = json_map_leaves(maybe_cast_dtype, output)

logger.warning_once(
f"{type(hf_processor).__name__} did not return `BatchFeature`. "
"Make sure to match the behaviour of `ProcessorMixin` when "
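
For context, a self-contained sketch of the casting behaviour this hunk is after, assuming a hypothetical `float16` model dtype and a local stand-in for the truncated `maybe_cast_dtype` helper (the real code maps it over the processor output with `json_map_leaves`): floating-point tensors are cast to the model dtype, integer tensors such as `input_ids` are left untouched, and the `BatchFeature` wrapper is preserved by mapping over its `.data` and re-wrapping.

```python
# Illustration only; the dtype and helper below are assumptions, not the PR's exact code.
import torch
from transformers import BatchFeature

MODEL_DTYPE = torch.float16  # hypothetical model_config.dtype

def maybe_cast_dtype(x):
    # Cast floating-point tensors to the model dtype; leave everything else alone.
    if isinstance(x, torch.Tensor) and x.is_floating_point():
        return x.to(dtype=MODEL_DTYPE)
    return x

output = BatchFeature({
    "pixel_values": torch.rand(1, 3, 224, 224),   # float32 -> float16
    "input_ids": torch.tensor([[101, 42, 102]]),  # int64 stays int64
})

# Map over the underlying dict and re-wrap, so the result is still a BatchFeature.
cast_output = BatchFeature({k: maybe_cast_dtype(v) for k, v in output.data.items()})

assert isinstance(cast_output, BatchFeature)
assert cast_output["pixel_values"].dtype == torch.float16
assert cast_output["input_ids"].dtype == torch.int64
```

This is presumably why the explicit `.type(self.visual.dtype)` casts in the Qwen2-VL model files below can be dropped: the processor output already arrives in the model dtype.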
9 changes: 4 additions & 5 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -965,9 +965,9 @@ def _process_image_input(
grid_thw_list = grid_thw.tolist()

if image_input["type"] == "image_embeds":
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
image_embeds = image_input["image_embeds"]
else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
pixel_values = image_input["pixel_values"]
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)

# Split concatenated embeddings for each image item.
@@ -985,10 +985,9 @@ def _process_video_input(
grid_thw_list = grid_thw.tolist()

if video_input["type"] == "video_embeds":
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
video_embeds = video_input["video_embeds"]
else:
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype)
pixel_values_videos = video_input["pixel_values_videos"]
video_embeds = self.visual(pixel_values_videos,
grid_thw=grid_thw_list)

9 changes: 4 additions & 5 deletions vllm/model_executor/models/qwen2_vl.py
@@ -1208,9 +1208,9 @@ def _process_image_input(
assert grid_thw.ndim == 2

if image_input["type"] == "image_embeds":
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
image_embeds = image_input["image_embeds"]
else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
pixel_values = image_input["pixel_values"]
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

# Split concatenated embeddings for each image item.
@@ -1226,10 +1226,9 @@ def _process_video_input(
assert grid_thw.ndim == 2

if video_input["type"] == "video_embeds":
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
video_embeds = video_input["video_embeds"]
else:
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype)
pixel_values_videos = video_input["pixel_values_videos"]
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

# Split concatenated embeddings for each video item.
10 changes: 10 additions & 0 deletions vllm/utils.py
@@ -190,6 +190,16 @@
torch.int64: np.int64,
}


@contextlib.contextmanager
def set_default_torch_num_threads(num_threads: int):
"""Sets the default number of threads for PyTorch to the given value."""
old_num_threads = torch.get_num_threads()
torch.set_num_threads(num_threads)
yield
torch.set_num_threads(old_num_threads)


P = ParamSpec('P')
T = TypeVar("T")
U = TypeVar("U")
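
A usage sketch of the new helper (thread counts are illustrative): the thread count is lowered only for the duration of the `with` block and the previous value is put back afterwards. Note that the restore runs after the `yield`, so it only happens when the wrapped block exits normally.

```python
# Usage sketch for the new context manager; values shown are illustrative.
import torch
from vllm.utils import set_default_torch_num_threads

before = torch.get_num_threads()          # e.g. 16 on a 16-core machine

with set_default_torch_num_threads(1):
    assert torch.get_num_threads() == 1   # intra-op/OpenMP work is single-threaded here
    # engine = AsyncLLM.from_engine_args(engine_args)  # worker fork happens here

assert torch.get_num_threads() == before  # previous setting restored
```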