
Commit 91d13b2
Author: sixiang-google
Message: vllm engine core fix contd
Parent: 299d667

2 files changed: +24, -26 lines

tests/core/test_core_tpu.py (3 additions, 0 deletions)

@@ -143,11 +143,14 @@ def test_add_request(self):
         mock_engine_request.mm_inputs = []
         mock_engine_request.use_structured_output = False
         mock_engine_request.kv_transfer_params = None
+        mock_engine_request.pooling_params = None
+        mock_engine_request.sampling_params.guided_decoding = None
 
         # Mock the prefill engine's scheduler
         mock_prefill_scheduler = self.mock_prefill_engine_instance.scheduler
 
         # Call the method under test
+        mock_engine_request, _ = proc.preprocess_add_request(mock_engine_request)
         proc.add_request(mock_engine_request)
 
         # Assert that the request was added to the first prefill engine's scheduler
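
Note on the new call pattern: add_request no longer converts the EngineCoreRequest itself, so the test now runs the request through the processor's preprocess step before enqueueing. A minimal sketch of the resulting two-step flow, condensed from the test above; proc is the processor object from the test fixture, and the request_id assignment is an assumption standing in for setup not shown in this hunk:

    from unittest.mock import MagicMock

    # A request stub with the fields the new code paths read.
    mock_engine_request = MagicMock()
    mock_engine_request.request_id = "req-0"  # assumed set earlier in the fixture
    mock_engine_request.mm_inputs = []
    mock_engine_request.use_structured_output = False
    mock_engine_request.kv_transfer_params = None
    mock_engine_request.pooling_params = None
    mock_engine_request.sampling_params.guided_decoding = None

    # Two-step flow: preprocess first, then enqueue the processed request.
    mock_engine_request, _ = proc.preprocess_add_request(mock_engine_request)
    proc.add_request(mock_engine_request)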

tpu_commons/core/core_tpu.py (21 additions, 26 deletions)

@@ -229,40 +229,35 @@ def _create_engine_cores(
 
         return engine_cores
 
-    def _add_request(self, request: EngineCoreRequest) -> Request:
-        if request.mm_hashes is not None:
-            # Here, if hash exists for a multimodal input, then it will be
-            # fetched from the cache, else it will be added to the cache.
-            # Note that the cache here is mirrored with the client cache, so
-            # anything that has a hash must have a HIT cache entry here
-            # as well.
-            assert request.mm_inputs is not None
-            request.mm_inputs = self._prefill_engines[
-                0].mm_input_cache_server.get_and_update_p1(
-                    request.mm_inputs, request.mm_hashes)
-
-        req = Request.from_engine_core_request(request)
-
-        if req.use_structured_output:
-            # Start grammar compilation asynchronously
-            self._prefill_engines[0].structured_output_manager.grammar_init(
-                req)
-
-        return req
-
-    def add_request(self, request: EngineCoreRequest):
-        vllm_request = self._add_request(request)
+
+    def add_request(self, request: EngineCoreRequest, request_wave: int = 0):
+        # vllm_request = self._add_request(request)
 
         # TODO(fhzhang): support multiple prefill engines.
-        self._prefill_engines[0].scheduler.add_request(vllm_request)
-        self._requests[request.request_id] = vllm_request
+        if not isinstance(request.request_id, str):
+            raise TypeError(
+                f"request_id must be a string, got {type(request.request_id)}")
+
+        if pooling_params := request.pooling_params:
+            supported_pooling_tasks = [
+                task for task in self.get_supported_tasks()
+                if task in POOLING_TASKS
+            ]
+
+            if pooling_params.task not in supported_pooling_tasks:
+                raise ValueError(f"Unsupported task: {pooling_params.task!r} "
+                                 f"Supported tasks: {supported_pooling_tasks}")
+
+        self._prefill_engines[0].scheduler.add_request(request)
+        self._requests[request.request_id] = request
 
     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
         """Dispatch request from client."""
 
         if request_type == EngineCoreRequestType.ADD:
-            self.add_request(request)
+            req, request_wave = request
+            self.add_request(req)
         elif request_type == EngineCoreRequestType.ABORT:
             # TODO(fhzhang): we need to keep track of which engine is processing
             # the request and finish it there.
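
For reference, the validation that replaces the old _add_request conversion can be exercised on its own. A minimal runnable sketch under stated assumptions: POOLING_TASKS and the supported-task list are hypothetical stand-ins (in the real module they come from the vllm side), and validate_request is an illustrative free function, not a method of the engine:

    from types import SimpleNamespace

    POOLING_TASKS = ("embed", "classify")  # hypothetical stand-in

    def validate_request(request, supported_tasks):
        # Reject non-string request IDs up front, as add_request now does.
        if not isinstance(request.request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request.request_id)}")

        # Pooling requests must name a task the engine supports.
        if pooling_params := request.pooling_params:
            supported_pooling_tasks = [
                task for task in supported_tasks if task in POOLING_TASKS
            ]
            if pooling_params.task not in supported_pooling_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_params.task!r} "
                    f"Supported tasks: {supported_pooling_tasks}")

    # Passes: a string ID and a supported pooling task.
    ok = SimpleNamespace(request_id="req-1",
                         pooling_params=SimpleNamespace(task="embed"))
    validate_request(ok, supported_tasks=("embed", "generate"))

    # Would raise ValueError: "rank" is not a supported pooling task.
    bad = SimpleNamespace(request_id="req-2",
                          pooling_params=SimpleNamespace(task="rank"))
    # validate_request(bad, supported_tasks=("embed", "generate"))

Note also that the ADD payload handled by _handle_client_request is now a (request, request_wave) tuple; the wave is unpacked but not yet forwarded to add_request.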
