
Commit 6a470be

Author: sixiang-google
Parent: 299d667

Commit message: vllm engine core fix contd

File tree: 2 files changed (+23, −24 lines)

tests/core/test_core_tpu.py

Lines changed: 4 additions & 0 deletions
@@ -143,11 +143,15 @@ def test_add_request(self):
         mock_engine_request.mm_inputs = []
         mock_engine_request.use_structured_output = False
         mock_engine_request.kv_transfer_params = None
+        mock_engine_request.pooling_params = None
+        mock_engine_request.sampling_params.guided_decoding = None
 
         # Mock the prefill engine's scheduler
         mock_prefill_scheduler = self.mock_prefill_engine_instance.scheduler
 
         # Call the method under test
+        mock_engine_request, _ = proc.preprocess_add_request(
+            mock_engine_request)
         proc.add_request(mock_engine_request)
 
         # Assert that the request was added to the first prefill engine's scheduler
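The four added lines track an API change in the engine core: add_request no longer does its own preprocessing, so the test must first run the mocked request through preprocess_add_request and stub the two extra fields (pooling_params, sampling_params.guided_decoding) that the preprocessing path now touches. A minimal sketch of the resulting call pattern; the proc fixture and the discarded second return value follow the test, while the helper below is an illustrative assumption:

# Sketch only: mirrors the two-step flow the updated test exercises.
def submit(proc, engine_request):
    # preprocess_add_request returns the (possibly rewritten) request
    # plus a second value, which the test discards with `_`.
    engine_request, _ = proc.preprocess_add_request(engine_request)
    # add_request now expects an already-preprocessed request.
    proc.add_request(engine_request)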

tpu_commons/core/core_tpu.py

Lines changed: 19 additions & 24 deletions
@@ -19,6 +19,7 @@
 from vllm.v1.engine.core import EngineCore as vLLMEngineCore
 from vllm.v1.engine.core import EngineCoreProc as vLLMEngineCoreProc
 from vllm.v1.request import Request, RequestStatus
+from vllm.tasks import POOLING_TASKS
 
 from tpu_commons.core import disagg_executor, disagg_utils
 from tpu_commons.runner.utils import LatencyTracker
@@ -229,40 +230,34 @@ def _create_engine_cores(
 
         return engine_cores
 
-    def _add_request(self, request: EngineCoreRequest) -> Request:
-        if request.mm_hashes is not None:
-            # Here, if hash exists for a multimodal input, then it will be
-            # fetched from the cache, else it will be added to the cache.
-            # Note that the cache here is mirrored with the client cache, so
-            # anything that has a hash must have a HIT cache entry here
-            # as well.
-            assert request.mm_inputs is not None
-            request.mm_inputs = self._prefill_engines[
-                0].mm_input_cache_server.get_and_update_p1(
-                    request.mm_inputs, request.mm_hashes)
+    def add_request(self, request: EngineCoreRequest, request_wave: int = 0):
+        # vllm_request = self._add_request(request)
 
-        req = Request.from_engine_core_request(request)
-
-        if req.use_structured_output:
-            # Start grammar compilation asynchronously
-            self._prefill_engines[0].structured_output_manager.grammar_init(
-                req)
+        # TODO(fhzhang): support multiple prefill engines.
+        if not isinstance(request.request_id, str):
+            raise TypeError(
+                f"request_id must be a string, got {type(request.request_id)}")
 
-        return req
+        if pooling_params := request.pooling_params:
+            supported_pooling_tasks = [
+                task for task in self.get_supported_tasks()
+                if task in POOLING_TASKS
+            ]
 
-    def add_request(self, request: EngineCoreRequest):
-        vllm_request = self._add_request(request)
+            if pooling_params.task not in supported_pooling_tasks:
+                raise ValueError(f"Unsupported task: {pooling_params.task!r} "
+                                 f"Supported tasks: {supported_pooling_tasks}")
 
-        # TODO(fhzhang): support multiple prefill engines.
-        self._prefill_engines[0].scheduler.add_request(vllm_request)
-        self._requests[request.request_id] = vllm_request
+        self._prefill_engines[0].scheduler.add_request(request)
+        self._requests[request.request_id] = request
 
     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
         """Dispatch request from client."""
 
         if request_type == EngineCoreRequestType.ADD:
-            self.add_request(request)
+            req, request_wave = request
+            self.add_request(req)
         elif request_type == EngineCoreRequestType.ABORT:
             # TODO(fhzhang): we need to keep track of which engine is processing
             # the request and finish it there.
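On the engine side, add_request now validates inline instead of delegating to the deleted _add_request helper: it type-checks request_id, and for pooling requests it intersects the engine's supported tasks with POOLING_TASKS before scheduling. A self-contained sketch of that validation pattern follows; the POOLING_TASKS tuple, the get_supported_tasks return value, and the request/params classes are stand-ins for illustration, not the real vLLM definitions:

from dataclasses import dataclass
from typing import Optional

POOLING_TASKS = ("embed", "classify")  # stand-in for vllm.tasks.POOLING_TASKS

@dataclass
class PoolingParams:                   # stand-in for vLLM's pooling params
    task: str

@dataclass
class FakeRequest:                     # stand-in for EngineCoreRequest
    request_id: str
    pooling_params: Optional[PoolingParams] = None

class EngineCoreSketch:
    def get_supported_tasks(self):
        # Assumed shape: every task this engine can serve.
        return ("generate", "embed")

    def add_request(self, request, request_wave: int = 0):
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}")
        # The walrus operator binds and tests pooling_params in one step,
        # so generative requests (pooling_params is None) skip the check.
        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in self.get_supported_tasks()
                if task in POOLING_TASKS
            ]
            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                                 f"Supported tasks: {supported_pooling_tasks}")
        # ...scheduling would follow here.

engine = EngineCoreSketch()
engine.add_request(FakeRequest("req-1", PoolingParams("embed")))    # passes
# engine.add_request(FakeRequest("req-2", PoolingParams("score")))  # would raise

Note that the EngineCoreRequest itself is now what gets scheduled and stored in self._requests; the old path that converted it into a Request via Request.from_engine_core_request (and kicked off grammar compilation for structured output) has moved out of this method, presumably into the preprocess_add_request step exercised by the test above.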

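One more behavioral change hides in _handle_client_request: the ADD payload is now a (request, request_wave) tuple rather than a bare request, and the unpacked request_wave is dropped before the call (add_request defaults its own request_wave parameter to 0). A hypothetical pairing of sender and receiver, where only the tuple shape and the receiver's unpacking come from the diff:

# Hypothetical sender: wraps the request with its wave number.
def encode_add_payload(engine_core_request, wave=0):
    return (engine_core_request, wave)

# Receiver, mirroring the diff: unpack, then forward only the request.
# request_wave is bound here but not yet threaded through add_request.
def handle_add(proc, payload):
    req, request_wave = payload
    proc.add_request(req)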