4 changes: 4 additions & 0 deletions tests/core/test_core_tpu.py
@@ -143,11 +143,15 @@ def test_add_request(self):
mock_engine_request.mm_inputs = []
mock_engine_request.use_structured_output = False
mock_engine_request.kv_transfer_params = None
mock_engine_request.pooling_params = None
mock_engine_request.sampling_params.guided_decoding = None

# Mock the prefill engine's scheduler
mock_prefill_scheduler = self.mock_prefill_engine_instance.scheduler

# Call the method under test
mock_engine_request, _ = proc.preprocess_add_request(
mock_engine_request)
proc.add_request(mock_engine_request)

# Assert that the request was added to the first prefill engine's scheduler
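For context (not part of the diff): the test now stubs `pooling_params` and `sampling_params.guided_decoding` because the new `add_request` path inspects both, and it runs the request through `preprocess_add_request` before calling `add_request`, since `add_request` no longer converts the raw `EngineCoreRequest` itself. A minimal sketch of that flow, using mocks as stand-ins for the processor under test and assuming `preprocess_add_request` returns the request plus a second value that the test ignores:

```python
from unittest.mock import MagicMock

# Request stub with only the fields the new add_request path touches.
req = MagicMock()
req.request_id = "req-0"                    # add_request now requires a str id
req.mm_inputs = []
req.use_structured_output = False
req.kv_transfer_params = None
req.pooling_params = None                   # skips the pooling-task validation
req.sampling_params.guided_decoding = None

# Stand-in for the engine-core processor; the real test uses the class under
# test with mocked prefill engines instead of a bare MagicMock.
proc = MagicMock()
proc.preprocess_add_request.return_value = (req, 0)

# Step 1: preprocess the raw EngineCoreRequest; the second return value is
# unpacked and discarded, as in the test above.
req, _ = proc.preprocess_add_request(req)

# Step 2: enqueue the preprocessed request on the first prefill engine's scheduler.
proc.add_request(req)
```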
43 changes: 19 additions & 24 deletions tpu_commons/core/core_tpu.py
@@ -13,6 +13,7 @@
import jax
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.tasks import POOLING_TASKS
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput,
UtilityResult)
@@ -229,40 +230,34 @@ def _create_engine_cores(

return engine_cores

def _add_request(self, request: EngineCoreRequest) -> Request:
if request.mm_hashes is not None:
# Here, if hash exists for a multimodal input, then it will be
# fetched from the cache, else it will be added to the cache.
# Note that the cache here is mirrored with the client cache, so
# anything that has a hash must have a HIT cache entry here
# as well.
assert request.mm_inputs is not None
request.mm_inputs = self._prefill_engines[
0].mm_input_cache_server.get_and_update_p1(
request.mm_inputs, request.mm_hashes)
def add_request(self, request: EngineCoreRequest, request_wave: int = 0):
# vllm_request = self._add_request(request)

req = Request.from_engine_core_request(request)

if req.use_structured_output:
# Start grammar compilation asynchronously
self._prefill_engines[0].structured_output_manager.grammar_init(
req)
# TODO(fhzhang): support multiple prefill engines.
if not isinstance(request.request_id, str):
raise TypeError(
f"request_id must be a string, got {type(request.request_id)}")

return req
if pooling_params := request.pooling_params:
supported_pooling_tasks = [
task for task in self.get_supported_tasks()
if task in POOLING_TASKS
]

def add_request(self, request: EngineCoreRequest):
vllm_request = self._add_request(request)
if pooling_params.task not in supported_pooling_tasks:
raise ValueError(f"Unsupported task: {pooling_params.task!r} "
f"Supported tasks: {supported_pooling_tasks}")

# TODO(fhzhang): support multiple prefill engines.
self._prefill_engines[0].scheduler.add_request(vllm_request)
self._requests[request.request_id] = vllm_request
self._prefill_engines[0].scheduler.add_request(request)
self._requests[request.request_id] = request

def _handle_client_request(self, request_type: EngineCoreRequestType,
request: Any) -> None:
"""Dispatch request from client."""

if request_type == EngineCoreRequestType.ADD:
self.add_request(request)
req, request_wave = request
self.add_request(req)
elif request_type == EngineCoreRequestType.ABORT:
# TODO(fhzhang): we need to keep track of which engine is processing
# the request and finish it there.
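As a reading aid for the guard added to `add_request` above, here is a hedged, standalone sketch of the pooling-task validation in isolation. The task sets below are illustrative stand-ins: in the real code they come from `vllm.tasks.POOLING_TASKS` and the engine's `get_supported_tasks()`, and `pooling_params` is the request's pooling parameters object.

```python
# Illustrative stand-ins; the real values come from vllm.tasks.POOLING_TASKS
# and self.get_supported_tasks() inside the engine core.
POOLING_TASKS = ("encode", "embed", "classify", "score")
ENGINE_SUPPORTED_TASKS = ("generate", "embed")


def validate_pooling_task(pooling_params) -> None:
    """Mirror of the guard added to add_request: reject pooling requests whose
    task the engine cannot serve; generation requests pass through untouched."""
    if pooling_params is None:
        return  # not a pooling request, nothing to validate
    supported_pooling_tasks = [
        task for task in ENGINE_SUPPORTED_TASKS if task in POOLING_TASKS
    ]
    if pooling_params.task not in supported_pooling_tasks:
        raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                         f"Supported tasks: {supported_pooling_tasks}")
```

With the stand-in values above, a request whose `pooling_params.task` is `"embed"` passes, while `"classify"` raises `ValueError` because it is a pooling task the engine does not report as supported.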