6 changes: 5 additions & 1 deletion vllm/config.py
@@ -425,6 +425,9 @@ class ModelConfig:
- "transformers" will use the Transformers model implementation."""
override_attention_dtype: Optional[str] = None
"""Override dtype for attention"""
process_hidden_states: Optional[bool] = False
Member: If this can never be None, we shouldn't hint it as Optional

Suggested change:
    process_hidden_states: Optional[bool] = False
    process_hidden_states: bool = False
"""Extract the hidden states of the model to be processed before the request
is completed. This is so far only supported for embedding/pooling models."""

def compute_hash(self) -> str:
"""
@@ -4820,7 +4823,8 @@ def __str__(self):
f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa
f"use_async_output_proc={self.model_config.use_async_output_proc}, "
f"pooler_config={self.model_config.pooler_config!r}, "
f"compilation_config={self.compilation_config!r}")
f"compilation_config={self.compilation_config!r}"
f"process_hidden_states={self.model_config.process_hidden_states}")


_current_vllm_config: Optional[VllmConfig] = None
4 changes: 4 additions & 0 deletions vllm/engine/arg_utils.py
@@ -350,6 +350,7 @@ class EngineArgs:
MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = \
MultiModalConfig.disable_mm_preprocessor_cache
process_hidden_states: bool = False
Member: With this change, the default is only defined in one place

Suggested change:
    process_hidden_states: bool = False
    process_hidden_states: bool = ModelConfig.process_hidden_states

# LoRA fields
enable_lora: bool = False
enable_lora_bias: bool = LoRAConfig.bias_enabled
@@ -503,6 +504,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
**model_kwargs["enable_prompt_embeds"])
model_group.add_argument("--served-model-name",
**model_kwargs["served_model_name"])
model_group.add_argument("--process-hidden-states",
**model_kwargs["process_hidden_states"])
# This one is a special case because it is the
# opposite of ModelConfig.use_async_output_proc
model_group.add_argument(
@@ -910,6 +913,7 @@ def create_model_config(self) -> ModelConfig:
enable_sleep_mode=self.enable_sleep_mode,
model_impl=self.model_impl,
override_attention_dtype=self.override_attention_dtype,
process_hidden_states=self.process_hidden_states,
)

def validate_tensorizer_args(self):
5 changes: 5 additions & 0 deletions vllm/entrypoints/llm.py
@@ -147,6 +147,9 @@ class LLM:
compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the level of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration.
process_hidden_states: If True, it loads the hidden states processor
to process the hidden states for each request before returning
to the user.

Contributor: typo: hiddne
**kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].

Note:
@@ -195,6 +198,7 @@ def __init__(
override_pooler_config: Optional[PoolerConfig] = None,
compilation_config: Optional[Union[int, dict[str, Any],
CompilationConfig]] = None,
process_hidden_states: bool = False,
**kwargs,
) -> None:
"""LLM constructor."""
@@ -268,6 +272,7 @@ def __init__(
mm_processor_kwargs=mm_processor_kwargs,
override_pooler_config=override_pooler_config,
compilation_config=compilation_config_instance,
process_hidden_states=process_hidden_states,
**kwargs,
)
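
A minimal end-to-end sketch of the new flag. The model name is illustrative, and it assumes a pooling/embedding model plus a vLLM build where LLM.embed is available; with no third-party plugin installed, the built-in identity processor is used:

from vllm import LLM

# Hypothetical embedding model; any pooling model should work.
llm = LLM(model="intfloat/e5-mistral-7b-instruct",
          process_hidden_states=True)
outputs = llm.embed(["Hello, world!"])
# Each PoolingOutput now carries the processed hidden states
# alongside the pooled data.
print(outputs[0].outputs.data.shape)
print(type(outputs[0].outputs.processed_hidden_states))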

5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -991,6 +991,11 @@ def get_vllm_port() -> Optional[int]:
# The default value is "VLLM".
"VLLM_PROCESS_NAME_PREFIX":
lambda: os.getenv("VLLM_PROCESS_NAME_PREFIX", "VLLM"),
# Controls which hidden states processor plugin to load.
# This is used when more than one hidden states processor is installed
# to decide which one to use.
"VLLM_USE_HIDDEN_STATES_PROCESSOR":
lambda: os.getenv("VLLM_USE_HIDDEN_STATES_PROCESSOR", None),
}

# --8<-- [end:env-vars-definition]
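
A small usage sketch for this variable: when more than one processor plugin is installed, it selects which one to load ("identity" is the built-in default registered by vLLM itself). It must be set before the engine is constructed:

import os

# Select the built-in identity processor explicitly.
os.environ["VLLM_USE_HIDDEN_STATES_PROCESSOR"] = "identity"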
6 changes: 5 additions & 1 deletion vllm/outputs.py
@@ -69,9 +69,13 @@ class PoolingOutput:
data: The extracted hidden states.
"""
data: torch.Tensor
processed_hidden_states: Optional[Any] = None

def __repr__(self) -> str:
return (f"PoolingOutput(data={self.data})")
hidden_states = ("None" if self.processed_hidden_states is None else
                 type(self.processed_hidden_states).__name__)
return (f"PoolingOutput(data={self.data}, "
        f"processed_hidden_states={hidden_states})")

def __eq__(self, other: object) -> bool:
return (isinstance(other, self.__class__) and bool(
82 changes: 82 additions & 0 deletions vllm/plugins/hidden_states_processors/__init__.py
@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import logging
from typing import Optional

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.plugins import load_plugins_by_group
from vllm.plugins.hidden_states_processors.interface import (
HiddenStatesProcessor)
from vllm.utils import resolve_obj_by_qualname

logger = logging.getLogger(__name__)


def identity_hidden_states_processor() -> str:
return ("vllm.plugins.hidden_states_processors."
"default.IdentityHiddenStatesProcessor")


default_hidden_states_processors = {
"identity": identity_hidden_states_processor
}


def get_hidden_states_processor(
vllm_config: VllmConfig) -> Optional["HiddenStatesProcessor"]:
# hidden states processors are loaded as plugins under the
# 'vllm.hidden_states_processor_plugins' group. Similar to platform
# plugins, these plugins register a function that returns the class
# name for the processor to install.
# All hidden state plugins implement the HiddenStatesProcessor class

hidden_states_processor_plugins = \
load_plugins_by_group('vllm.hidden_states_processor_plugins')

available_plugins = {
**default_hidden_states_processors,
**hidden_states_processor_plugins
}

loadable_plugins = {}
for name, func in available_plugins.items():
try:
assert callable(func)
processor_cls_qualname = func()
if processor_cls_qualname is not None:
loadable_plugins[name] = processor_cls_qualname
except Exception:
pass
Comment on lines +50 to +51

Contributor (high): The broad except Exception: pass will silently ignore any errors during plugin loading, which can make debugging very difficult. For instance, if a plugin's entry point has a bug, it will be skipped without any notification. It's better to log the exception to provide visibility into loading failures.

Suggested change:
        except Exception:
            pass
        except Exception:
            logger.warning(
                "Failed to load hidden states processor plugin '%s'.",
                name,
                exc_info=True)


num_available_plugins = len(loadable_plugins.keys())

# Sanity check: at least the built-in identity processor
# should always be available
assert num_available_plugins > 0

if num_available_plugins > 1 and envs.VLLM_USE_HIDDEN_STATES_PROCESSOR:
plugin_name = envs.VLLM_USE_HIDDEN_STATES_PROCESSOR
if plugin_name not in loadable_plugins:
raise ValueError(
f"Hidden states processor plugin '{plugin_name}' not found. "
f"Available plugins: {list(loadable_plugins.keys())}")

activated_plugin_cls = loadable_plugins[plugin_name]
activated_plugin_name = envs.VLLM_USE_HIDDEN_STATES_PROCESSOR
else:
activated_plugin_name = list(loadable_plugins.keys())[0]
activated_plugin_cls = loadable_plugins[activated_plugin_name]
if (num_available_plugins > 1
and not envs.VLLM_USE_HIDDEN_STATES_PROCESSOR):
logger.info(
"Multiple hidden states processor plugins available "
"but VLLM_USE_HIDDEN_STATES_PROCESSOR is not pointing "
"to any specific plugins. Loading the first available one.\n"
"Available hidden states "
"processor plugins %s", str(loadable_plugins.keys()))

logger.info("Loaded hidden states processor plugin: %s",
activated_plugin_name)
return resolve_obj_by_qualname(activated_plugin_cls)(vllm_config)
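
For third-party processors, a minimal registration sketch under the 'vllm.hidden_states_processor_plugins' entry-point group (the package, module, and class names are hypothetical):

# my_plugin/__init__.py
def register() -> str:
    # Return the fully qualified class name; vLLM resolves it via
    # resolve_obj_by_qualname and instantiates it with the VllmConfig.
    return "my_plugin.processor.MyHiddenStatesProcessor"

# In the plugin package's pyproject.toml:
# [project.entry-points."vllm.hidden_states_processor_plugins"]
# my_processor = "my_plugin:register"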
19 changes: 19 additions & 0 deletions vllm/plugins/hidden_states_processors/default.py
@@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any

import torch

from vllm.plugins.hidden_states_processors.interface import (
HiddenStatesProcessor)


class IdentityHiddenStatesProcessor(HiddenStatesProcessor):

def apply(self, data: torch.Tensor) -> Any:
"""
This is the default identity hidden states processor
that returns the hidden_states data as is.
"""
return data
19 changes: 19 additions & 0 deletions vllm/plugins/hidden_states_processors/interface.py
@@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from typing import Any

import torch

from vllm.config import VllmConfig


class HiddenStatesProcessor(ABC):

def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config

@abstractmethod
def apply(self, data: torch.Tensor) -> Any:
...
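
A minimal sketch of a concrete implementation of this interface; the class name and behavior (detaching and moving the hidden states to CPU) are illustrative only:

import torch

from vllm.plugins.hidden_states_processors.interface import (
    HiddenStatesProcessor)


class CpuHiddenStatesProcessor(HiddenStatesProcessor):

    def apply(self, data: torch.Tensor) -> torch.Tensor:
        # Detach from the autograd graph and move to CPU so the result
        # can be returned safely with the request output.
        return data.detach().to("cpu")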
6 changes: 6 additions & 0 deletions vllm/v1/core/sched/scheduler.py
@@ -757,6 +757,7 @@ def update_from_output(
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits
hidden_states = model_runner_output.hidden_states

outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: Optional[SpecDecodingStats] = None
@@ -821,6 +822,10 @@
else:
stopped_preempted_reqs.add(request)

req_hidden_states = None
if hidden_states:
req_hidden_states = hidden_states[req_index]

# Extract sample logprobs if needed.
if request.sampling_params is not None \
and request.sampling_params.logprobs is not None and logprobs:
@@ -864,6 +869,7 @@
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
hidden_states=req_hidden_states,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
1 change: 1 addition & 0 deletions vllm/v1/engine/__init__.py
@@ -109,6 +109,7 @@ class EngineCoreOutput(
new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None

pooling_output: Optional[torch.Tensor] = None
hidden_states: Optional[torch.Tensor] = None

finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
7 changes: 5 additions & 2 deletions vllm/v1/engine/async_llm.py
@@ -111,8 +111,11 @@ def __init__(
)

# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
self.output_processor = OutputProcessor(self.tokenizer,
log_stats=self.log_stats)
self.output_processor = OutputProcessor(
vllm_config=vllm_config,
tokenizer=self.tokenizer,
log_stats=self.log_stats,
)

# EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_async_mp_client(
3 changes: 2 additions & 1 deletion vllm/v1/engine/llm_engine.py
@@ -97,7 +97,8 @@ def __init__(
mm_registry=mm_registry)

# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
self.output_processor = OutputProcessor(self.tokenizer,
self.output_processor = OutputProcessor(vllm_config=vllm_config,
tokenizer=self.tokenizer,
log_stats=self.log_stats)

# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
35 changes: 29 additions & 6 deletions vllm/v1/engine/output_processor.py
@@ -8,8 +8,10 @@

import torch

from vllm.config import VllmConfig
from vllm.outputs import (CompletionOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput)
from vllm.plugins.hidden_states_processors import get_hidden_states_processor
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -164,6 +166,7 @@ def make_request_output(
self,
new_token_ids: list[int],
pooling_output: Optional[torch.Tensor],
processed_hidden_states: Optional[Any],
finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
kv_transfer_params: Optional[dict[str, Any]] = None,
@@ -179,9 +182,12 @@

request_id = self.request_id
if pooling_output is not None:
return self._new_request_output(
request_id, [self._new_pooling_output(pooling_output)],
finished)
output = self._new_pooling_output(
pooling_output,
processed_hidden_states=processed_hidden_states)
return self._new_request_output(request_id=request_id,
outputs=[output],
finished=finished)

output = self._new_completion_output(new_token_ids, finish_reason,
stop_reason)
@@ -266,16 +272,19 @@ def _new_completion_output(
def _new_pooling_output(
self,
pooling_output: torch.Tensor,
processed_hidden_states: Any,
) -> PoolingOutput:

return PoolingOutput(data=pooling_output)
return PoolingOutput(data=pooling_output,
processed_hidden_states=processed_hidden_states)


class OutputProcessor:
"""Process EngineCoreOutputs into RequestOutputs."""

def __init__(
self,
vllm_config: VllmConfig,
tokenizer: TokenizerGroup,
log_stats: bool,
):
@@ -284,6 +293,11 @@ def __init__(
self.request_states: dict[str, RequestState] = {}
self.parent_requests: dict[str, ParentRequest] = {}
self.lora_states = LoRARequestStates()
if vllm_config.model_config.process_hidden_states:
if not (processor := (get_hidden_states_processor(vllm_config))):
raise ValueError(
    "process_hidden_states is set but no hidden states "
    "processor plugin is available")
self.hidden_states_processor = processor

def get_num_unfinished_requests(self):
return len(self.request_states)
@@ -391,6 +405,7 @@ def process_outputs(
stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params
num_cached_tokens = engine_core_output.num_cached_tokens
hidden_states = engine_core_output.hidden_states
req_state.is_prefilling = False

if pooling_output is None:
@@ -408,10 +423,18 @@
req_state.logprobs_processor.update_from_output(
engine_core_output)

if pooling_output is not None and hidden_states is not None:
# Currently we process hidden states only for pooling models
processed_hidden_states = \
self.hidden_states_processor.apply(hidden_states)
else:
processed_hidden_states = None

# 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output(
new_token_ids, pooling_output, finish_reason, stop_reason,
kv_transfer_params, num_cached_tokens):
new_token_ids, pooling_output, processed_hidden_states,
finish_reason, stop_reason, kv_transfer_params,
num_cached_tokens):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output)
3 changes: 3 additions & 0 deletions vllm/v1/outputs.py
@@ -111,6 +111,9 @@ class ModelRunnerOutput:
# req_id -> num_nans_in_logits
num_nans_in_logits: Optional[dict[str, int]] = None

# This is used for pooling models that install a hidden states processor
hidden_states: Optional[list[torch.Tensor]] = None


EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
req_id_to_index={},