
[BugFix] Prevent LLM.encode for non-generation Models (vllm-project#5184)

Co-authored-by: mgoin <michael@neuralmagic.com>
robertgshaw2-neuralmagic and mgoin committed Jun 11, 2024
1 parent fd82eff commit 5b6b8ed
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions vllm/entrypoints/llm.py
@@ -285,6 +285,11 @@ def generate(
             considered legacy and may be deprecated in the future. You should
             instead pass them via the ``inputs`` parameter.
         """
+        if self.llm_engine.model_config.embedding_mode:
+            raise ValueError(
+                "LLM.generate() is only supported for generation models "
+                "(XForCausalLM).")
+
         if prompt_token_ids is not None or multi_modal_data is not None:
             inputs = self._convert_v1_inputs(
                 prompts=cast(Optional[Union[str, List[str]]], prompts),
@@ -429,6 +434,11 @@ def encode(
             considered legacy and may be deprecated in the future. You should
             instead pass them via the ``inputs`` parameter.
         """
+        if not self.llm_engine.model_config.embedding_mode:
+            raise ValueError(
+                "LLM.encode() is only supported for embedding models (XModel)."
+            )
+
         if prompt_token_ids is not None or multi_modal_data is not None:
             inputs = self._convert_v1_inputs(
                 prompts=cast(Optional[Union[str, List[str]]], prompts),
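
For reference, a minimal usage sketch of how these new guards surface to callers, assuming a vLLM build that includes this commit. The model names are illustrative placeholders (not part of the commit): any *ForCausalLM generation model and any supported embedding model should behave the same way.

# Hypothetical sketch, not part of the commit: shows the ValueError raised by
# the new guards when LLM.encode/LLM.generate is called on the wrong model type.
from vllm import LLM

# A decoder-only generation model: generate() works, encode() is now rejected.
gen_llm = LLM(model="facebook/opt-125m")
gen_llm.generate("Hello, my name is")          # OK
try:
    gen_llm.encode("Hello, my name is")
except ValueError as e:
    # "LLM.encode() is only supported for embedding models (XModel)."
    print(e)

# An embedding model: encode() works, generate() is now rejected.
# (Instantiate in a separate process if GPU memory is limited.)
emb_llm = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
emb_llm.encode("Hello, my name is")            # OK
try:
    emb_llm.generate("Hello, my name is")
except ValueError as e:
    # "LLM.generate() is only supported for generation models (XForCausalLM)."
    print(e)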
