vllm/engine/arg_utils.py (1 addition, 1 deletion)

@@ -1259,7 +1259,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             return False
 
         # No Embedding Models so far.
-        if model_config.task not in ["generate"]:
+        if model_config.task not in ["generate", "embed", "classify", "score", "reward"]:
             _raise_or_fallback(feature_name=f"--task {model_config.task}",
                                recommend_to_remove=False)
             return False
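With this change, pooling tasks ("embed", "classify", "score", "reward") pass the V1 support oracle instead of triggering the V0 fallback. A minimal offline sketch of exercising the new path, assuming the V1 engine is enabled via VLLM_USE_V1; the model name and prompt are illustrative, not part of this diff:

    import os

    # Assumption: opt in to the V1 engine before importing vllm.
    os.environ["VLLM_USE_V1"] = "1"

    from vllm import LLM

    # task="embed" now passes _is_v1_supported_oracle instead of falling back to V0.
    llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")

    # LLM.encode() drives the pooling path; each result carries the pooled output.
    outputs = llm.encode(["vLLM is a fast inference engine."])
    print(outputs[0].outputs)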
vllm/entrypoints/llm.py (1 addition)

@@ -241,6 +241,7 @@ def __init__(
             **kwargs,
         )
 
+        logger.info(f"Engine args: {engine_args}")
         # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
             engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
vllm/v1/core/sched/output.py (3 additions)

@@ -13,6 +13,7 @@
                                                         KVConnectorMetadata)
 from vllm.lora.request import LoRARequest
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.v1.request import Request
 
@@ -26,6 +27,7 @@ class NewRequestData:
     mm_hashes: list[str]
     mm_positions: list[PlaceholderRange]
     sampling_params: SamplingParams
+    pooling_params: PoolingParams
     block_ids: list[int]
     num_computed_tokens: int
     lora_request: Optional[LoRARequest]
@@ -43,6 +45,7 @@ def from_request(
             mm_hashes=request.mm_hashes,
             mm_positions=request.mm_positions,
             sampling_params=request.sampling_params,
+            pooling_params=request.pooling_params,
             block_ids=block_ids,
             num_computed_tokens=request.num_computed_tokens,
             lora_request=request.lora_request,
vllm/v1/core/sched/scheduler.py (16 additions)

@@ -646,6 +646,7 @@ def update_from_output(
         spec_token_ids = model_runner_output.spec_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
+        hidden_states = model_runner_output.hidden_states
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
 
         new_running: list[Request] = []
@@ -663,6 +664,21 @@
                 new_running.append(request)
                 continue
 
+            if hidden_states is not None:
+                request.status = RequestStatus.FINISHED_STOPPED
+                self._free_request(request)
+                outputs.append(
+                    EngineCoreOutput(
+                        request_id=req_id,
+                        new_token_ids=[],
+                        finish_reason=request.get_finished_reason(),
+                        new_logprobs=None,
+                        new_prompt_logprobs_tensors=None,
+                        stop_reason=request.stop_reason,
+                        events=request.take_events(),
+                        hidden_states=hidden_states))
+                continue
+
             req_index = model_runner_output.req_id_to_index[req_id]
             generated_token_ids = sampled_token_ids[req_index]
 
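In this prototype the scheduler finishes a pooling request as soon as the model runner reports hidden states and attaches the tensor to the outgoing EngineCoreOutput. A minimal sketch of how a consumer on the frontend side might read it back; the helper name and the assumed tensor layout are illustrative assumptions, not part of this diff:

    from typing import Optional

    import torch

    from vllm.v1.engine import EngineCoreOutput

    def extract_pooled_output(output: EngineCoreOutput) -> Optional[torch.Tensor]:
        # hidden_states is only populated on the pooling path; generation
        # requests leave it as None.
        if output.hidden_states is None:
            return None
        # Assumption: the runner hands back hidden states for the whole
        # scheduled batch, so downstream code may still need to select the
        # row that belongs to output.request_id before pooling/normalizing.
        return output.hidden_states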
vllm/v1/engine/__init__.py (5 additions)

@@ -6,10 +6,12 @@
 from typing import Any, Optional, Union
 
 import msgspec
+import torch
 
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalKwargs
 from vllm.multimodal.inputs import PlaceholderRange
+from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import LogprobsLists, LogprobsTensors
@@ -53,6 +55,7 @@ class EngineCoreRequest(
     mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
     mm_hashes: Optional[list[str]]
     mm_placeholders: Optional[list[PlaceholderRange]]
+    pooling_params: Optional[PoolingParams]
     sampling_params: SamplingParams
     eos_token_id: Optional[int]
     arrival_time: float
@@ -106,6 +109,8 @@ class EngineCoreOutput(
     stop_reason: Union[int, str, None] = None
     events: Optional[list[EngineCoreEvent]] = None
 
+    hidden_states: Optional[torch.Tensor] = None
+
     @property
     def finished(self) -> bool:
         return self.finish_reason is not None
vllm/v1/engine/async_llm.py (61 additions, 7 deletions)

@@ -19,7 +19,7 @@
 from vllm.outputs import RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
@@ -212,19 +212,19 @@ async def add_request(
         if self.errored:
             raise EngineDeadError()
 
-        assert isinstance(params, SamplingParams), \
-            "Pooling is not supported in V1"
+        # assert isinstance(params, SamplingParams), \
+        #     "Pooling is not supported in V1"
 
         # Create a new output collector for the request.
-        queue = RequestOutputCollector(output_kind=params.output_kind)
+        queue = RequestOutputCollector(output_kind=RequestOutputKind.CUMULATIVE)
 
         # Convert Input --> Request.
         prompt_str, request = self.processor.process_inputs(
             request_id, prompt, params, arrival_time, lora_request,
             tokenization_kwargs, trace_headers, prompt_adapter_request,
             priority)
 
-        if params.n == 1:
+        if isinstance(params, PoolingParams) or params.n == 1:
             await self._add_request(request, prompt_str, None, 0, queue)
             return queue
 
@@ -425,7 +425,7 @@ def _record_stats(
             stat_logger.record(scheduler_stats=scheduler_stats,
                                iteration_stats=iteration_stats)
 
-    def encode(
+    async def encode(
         self,
         prompt: PromptType,
         pooling_params: PoolingParams,
@@ -434,7 +434,61 @@ def encode(
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
     ):
-        raise ValueError("Not Supported on V1 yet.")
+        try:
+            # We start the output_handler on the first call to generate() so
+            # we can call __init__ before the event loop, which enables us
+            # to handle startup failure gracefully in the OpenAI server.
+            self._run_output_handler()
+
+            q = await self.add_request(
+                request_id,
+                prompt,
+                pooling_params,
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                prompt_adapter_request=None,
+                priority=priority,
+            )
+
+            # The output_handler task pushes items into the queue.
+            # This task pulls from the queue and yields to caller.
+            finished = False
+            while not finished:
+                # Note: drain queue without await if possible (avoids
+                # task switching under load which helps performance).
+                out = q.get_nowait() or await q.get()
+
+                # Note: both OutputProcessor and EngineCore handle their
+                # own request cleanup based on finished.
+                finished = out.finished
+                yield out
+
+        # If the request is disconnected by the client, generate()
+        # is cancelled. So, we abort the request if we end up here.
+        except asyncio.CancelledError:
+            await self.abort(request_id)
+            if self.log_requests:
+                logger.info("Request %s aborted.", request_id)
+            raise
+
+        # Engine is dead. Do not abort since we shut down.
+        except EngineDeadError:
+            if self.log_requests:
+                logger.info("Request %s failed (engine dead).", request_id)
+            raise
+
+        # Request validation error.
+        except ValueError:
+            if self.log_requests:
+                logger.info("Request %s failed (bad request).", request_id)
+            raise
+
+        # Unexpected error in the generate() task (possibly recoverable).
+        except Exception as e:
+            await self.abort(request_id)
+            if self.log_requests:
+                logger.info("Request %s failed.", request_id)
+            raise EngineGenerateError() from e
 
     async def get_vllm_config(self) -> VllmConfig:
         return self.vllm_config
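With the stub removed, AsyncLLM.encode() now mirrors generate(): it registers the request and then yields items from the per-request RequestOutputCollector until finished is set. A minimal sketch of driving it directly; the engine arguments, model, and request id are illustrative assumptions, not part of this diff:

    import asyncio

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.pooling_params import PoolingParams
    from vllm.v1.engine.async_llm import AsyncLLM

    async def main() -> None:
        # Assumption: an embedding model served through the V1 AsyncLLM frontend.
        engine = AsyncLLM.from_engine_args(
            AsyncEngineArgs(model="intfloat/e5-mistral-7b-instruct", task="embed"))

        # encode() is an async generator; the last item has finished=True.
        async for out in engine.encode("vLLM is a fast inference engine.",
                                       PoolingParams(),
                                       request_id="embed-0"):
            if out.finished:
                print(out)

    asyncio.run(main())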
vllm/v1/engine/core.py (1 addition, 1 deletion)

@@ -54,7 +54,7 @@ def __init__(self,
                  executor_class: type[Executor],
                  log_stats: bool,
                  executor_fail_callback: Optional[Callable] = None):
-        assert vllm_config.model_config.runner_type != "pooling"
+        # assert vllm_config.model_config.runner_type != "pooling"
 
         logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                     VLLM_VERSION, vllm_config)