121 changes: 30 additions & 91 deletions python/ray/llm/_internal/serve/configs/openai_api_models.py
@@ -6,104 +6,43 @@

from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union

from pydantic import (
BaseModel,
ConfigDict,
Field,
)
from pydantic import BaseModel, ConfigDict
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest as vLLMChatCompletionRequest,
ChatCompletionResponse as vLLMChatCompletionResponse,
ChatCompletionStreamResponse as vLLMChatCompletionStreamResponse,
CompletionRequest as vLLMCompletionRequest,
CompletionResponse as vLLMCompletionResponse,
CompletionStreamResponse as vLLMCompletionStreamResponse,
EmbeddingChatRequest as vLLMEmbeddingChatRequest,
EmbeddingCompletionRequest as vLLMEmbeddingCompletionRequest,
EmbeddingResponse as vLLMEmbeddingResponse,
ErrorInfo as vLLMErrorInfo,
ErrorResponse as vLLMErrorResponse,
ScoreRequest as vLLMScoreRequest,
ScoreResponse as vLLMScoreResponse,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
CompletionRequest,
CompletionResponse,
CompletionStreamResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingResponse,
ErrorInfo,
ErrorResponse,
ScoreRequest,
ScoreResponse,
)
from vllm.utils import random_uuid

__all__ = [
"ChatCompletionRequest",
"ChatCompletionResponse",
"ChatCompletionStreamResponse",
"CompletionRequest",
"CompletionResponse",
"CompletionStreamResponse",
"EmbeddingChatRequest",
"EmbeddingCompletionRequest",
"EmbeddingResponse",
"ErrorInfo",
"ErrorResponse",
"ScoreRequest",
"ScoreResponse",
]

if TYPE_CHECKING:
from ray.llm._internal.serve.configs.server_models import LLMConfig


class ChatCompletionRequest(vLLMChatCompletionRequest):
model_config = ConfigDict(arbitrary_types_allowed=True)


class ChatCompletionResponse(vLLMChatCompletionResponse):
model_config = ConfigDict(arbitrary_types_allowed=True)


class ChatCompletionStreamResponse(vLLMChatCompletionStreamResponse):
model_config = ConfigDict(arbitrary_types_allowed=True)


class ErrorInfo(vLLMErrorInfo):
model_config = ConfigDict(arbitrary_types_allowed=True)


class ErrorResponse(vLLMErrorResponse):
model_config = ConfigDict(arbitrary_types_allowed=True)


# TODO (Kourosh): Upstream
class CompletionRequest(vLLMCompletionRequest):
model_config = ConfigDict(arbitrary_types_allowed=True)

request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)


class CompletionResponse(vLLMCompletionResponse):
model_config = ConfigDict(arbitrary_types_allowed=True)


class CompletionStreamResponse(vLLMCompletionStreamResponse):
model_config = ConfigDict(arbitrary_types_allowed=True)


# TODO (Kourosh): Upstream
class EmbeddingCompletionRequest(vLLMEmbeddingCompletionRequest):
model_config = ConfigDict(arbitrary_types_allowed=True)

request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)


class EmbeddingChatRequest(vLLMEmbeddingChatRequest):
model_config = ConfigDict(arbitrary_types_allowed=True)


class EmbeddingResponse(vLLMEmbeddingResponse):
model_config = ConfigDict(arbitrary_types_allowed=True)


class ScoreRequest(vLLMScoreRequest):
model_config = ConfigDict(arbitrary_types_allowed=True)


class ScoreResponse(vLLMScoreResponse):
Review comment (Contributor Author): @kouroshHakha do you want to leave these here?
model_config = ConfigDict(arbitrary_types_allowed=True)


EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]

LLMEmbeddingsResponse = Union[
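In practical terms, this first file drops the Ray-side subclasses (and the Ray-specific `request_id` field they added) and simply re-exports the vLLM protocol models through `__all__`. A minimal sketch of what that means for importers, assuming the module path shown in the diff; the request construction itself is illustrative only:

```python
# The Ray module now re-exports vLLM's protocol models directly, so both
# import paths below refer to the same class object (no subclass wrapper).
from ray.llm._internal.serve.configs.openai_api_models import ChatCompletionRequest
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest as VLLMChatCompletionRequest,
)

assert ChatCompletionRequest is VLLMChatCompletionRequest

# Validation behavior therefore comes from vLLM's own pydantic model.
request = ChatCompletionRequest(
    model="my-model",
    messages=[{"role": "user", "content": "Hello"}],
)
```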
20 changes: 16 additions & 4 deletions python/ray/llm/_internal/serve/deployments/llm/llm_engine.py
@@ -1,5 +1,5 @@
import abc
from typing import TYPE_CHECKING, AsyncGenerator, Union
from typing import TYPE_CHECKING, AsyncGenerator, Optional, Union

from ray.llm._internal.serve.configs.server_models import (
DiskMultiplexConfig,
@@ -16,6 +16,9 @@
EmbeddingResponse,
ErrorResponse,
)
from ray.llm._internal.serve.deployments.llm.vllm.vllm_engine import (
RawRequestInfo,
)


class LLMEngine(abc.ABC):
Expand All @@ -42,7 +45,9 @@ async def reset_prefix_cache(self) -> None:

@abc.abstractmethod
async def chat(
self, request: "ChatCompletionRequest"
self,
request: "ChatCompletionRequest",
raw_request_info: Optional["RawRequestInfo"] = None,
) -> AsyncGenerator[Union[str, "ChatCompletionResponse", "ErrorResponse"], None]:
"""Run a ChatCompletion with the engine.

@@ -56,6 +61,7 @@ async def chat(

Args:
request: The chat completion request.
raw_request_info: Optional RawRequestInfo containing headers and state from the original request.

Yields:
Union[str, ChatCompletionResponse, ErrorResponse]: A string representing a chunk of the response, a ChatCompletionResponse object, or an ErrorResponse object.
@@ -67,7 +73,9 @@

@abc.abstractmethod
async def completions(
self, request: "CompletionRequest"
self,
request: "CompletionRequest",
raw_request_info: Optional["RawRequestInfo"] = None,
) -> AsyncGenerator[Union[str, "CompletionResponse", "ErrorResponse"], None]:
"""Run a Completion with the engine.

@@ -85,6 +93,7 @@ async def completions(

Args:
request: The completion request.
raw_request_info: Optional RawRequestInfo containing headers and state from the original request.

Yields:
Union[str, CompletionResponse, ErrorResponse]: A string
@@ -98,7 +107,9 @@

@abc.abstractmethod
async def embeddings(
self, request: "EmbeddingRequest"
self,
request: "EmbeddingRequest",
raw_request_info: Optional["RawRequestInfo"] = None,
) -> AsyncGenerator[Union["EmbeddingResponse", "ErrorResponse"], None]:
"""Run an Embedding with the engine.

@@ -112,6 +123,7 @@ async def embeddings(

Args:
request: The embedding request.
raw_request_info: Optional RawRequestInfo containing headers and state from the original request.

Returns:
An async generator that yields EmbeddingResponse objects or ErrorResponse objects, and returns None when the generator is done.
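The abstract `chat`, `completions`, and `embeddings` methods now accept an optional `RawRequestInfo`. Its real definition lives in `vllm_engine.py` and is not shown in this diff; the sketch below is a hypothetical reconstruction inferred only from how `llm_server.py` constructs it (a headers dict plus an optional state dict), with an illustrative helper showing how an engine implementation might consume it:

```python
# Hypothetical shape of RawRequestInfo, inferred from its construction in
# llm_server.py (headers=dict(raw_request.headers.items()), state=... or None).
# The actual class is defined in
# ray.llm._internal.serve.deployments.llm.vllm.vllm_engine and may differ.
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class RawRequestInfo:
    headers: Dict[str, str]                  # HTTP headers of the original request
    state: Optional[Dict[str, Any]] = None   # contents of request.state, if any


def request_id_from_headers(info: Optional[RawRequestInfo]) -> Optional[str]:
    """Illustrative engine-side use: read a per-request header if one was sent."""
    if info is None:
        return None
    return info.headers.get("x-request-id")
```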
47 changes: 38 additions & 9 deletions python/ray/llm/_internal/serve/deployments/llm/llm_server.py
@@ -13,6 +13,8 @@
Union,
)

from fastapi import Request

import ray
from ray import serve
from ray._common.utils import import_attr
@@ -27,7 +29,10 @@
LLMConfig,
)
from ray.llm._internal.serve.deployments.llm.llm_engine import LLMEngine
from ray.llm._internal.serve.deployments.llm.vllm.vllm_engine import VLLMEngine
from ray.llm._internal.serve.deployments.llm.vllm.vllm_engine import (
RawRequestInfo,
VLLMEngine,
)
from ray.llm._internal.serve.deployments.protocol import LLMServerProtocol
from ray.llm._internal.serve.deployments.utils.batcher import Batcher
from ray.llm._internal.serve.deployments.utils.server_utils import (
@@ -287,13 +292,15 @@ async def _run_request(
*,
engine_method: str,
batch_output_stream: bool = False,
raw_request: Optional[Request] = None,
) -> AsyncGenerator[Any, None]:
"""Run the engine method on the request + perform batching when stream=True.

Args:
request: The request to run.
engine_method: The method to call on the engine.
batch_output_stream: Whether to batch the output stream.
raw_request: The raw FastAPI request object.

Returns:
An AsyncGenerator of the response. If stream is True and batching is enabled, then the generator will yield a list of streaming responses (strings of the format data: {response_json}\n\n). Otherwise, it will yield the non-streaming response from engine directly.
@@ -302,25 +309,36 @@ async def _run_request(
await self._maybe_add_request_id_to_request(request)
await self._maybe_resolve_lora_from_multiplex()

# Extract serializable request info from raw_request
raw_request_info = None
if raw_request is not None:
raw_request_info = RawRequestInfo(
headers=dict(raw_request.headers.items()),
state=dict(raw_request.state._state)
if hasattr(raw_request.state, "_state")
else None,
)

is_stream = hasattr(request, "stream") and request.stream
if is_stream and batch_output_stream:
stream = self._batch_output_stream(
getattr(self.engine, engine_method)(request)
getattr(self.engine, engine_method)(request, raw_request_info)
)
else:
stream = getattr(self.engine, engine_method)(request)
stream = getattr(self.engine, engine_method)(request, raw_request_info)

return stream

async def chat(
self, request: "ChatCompletionRequest"
self, request: "ChatCompletionRequest", raw_request: Optional[Request] = None
) -> AsyncGenerator[
Union[List[Union[str, "ErrorResponse"]], "ChatCompletionResponse"], None
]:
"""Runs a chat request to the LLM engine and returns the response.

Args:
request: A ChatCompletionRequest object.
raw_request: The raw FastAPI request object.

Returns:
An AsyncGenerator of the response. If stream is True and batching is enabled, then the generator will yield a list of chat streaming responses (strings of the format data: {response_json}\n\n). Otherwise, it will yield the ChatCompletionResponse object directly.
@@ -329,17 +347,19 @@ async def chat(
request,
engine_method="chat",
batch_output_stream=True,
raw_request=raw_request,
)

async def completions(
self, request: "CompletionRequest"
self, request: "CompletionRequest", raw_request: Optional[Request] = None
) -> AsyncGenerator[
Union[List[Union[str, "ErrorResponse"]], "CompletionResponse"], None
]:
"""Runs a completion request to the LLM engine and returns the response.

Args:
request: A CompletionRequest object.
raw_request: The raw FastAPI request object.

Returns:
An AsyncGenerator of the response. If stream is True and batching is enabled, then the generator will yield a list of completion streaming responses (strings of the format data: {response_json}\n\n). Otherwise, it will yield the CompletionResponse object directly.
@@ -348,42 +368,51 @@ async def completions(
request,
engine_method="completions",
batch_output_stream=True,
raw_request=raw_request,
)

async def embeddings(
self, request: "EmbeddingRequest"
self, request: "EmbeddingRequest", raw_request: Optional[Request] = None
) -> AsyncGenerator[Union[List["ErrorResponse"], "EmbeddingResponse"], None]:
"""Runs an embeddings request to the engine and returns the response.

Returns an AsyncGenerator over the EmbeddingResponse object. This is so that the caller can have a consistent interface across all the methods of chat, completions, and embeddings.

Args:
request: An EmbeddingRequest object.
raw_request: The raw FastAPI request object.

Returns:
An AsyncGenerator over the EmbeddingResponse object.
"""
# NOTE: Embeddings does not need batching.
return await self._run_request(
request, engine_method="embeddings", batch_output_stream=False
request,
engine_method="embeddings",
batch_output_stream=False,
raw_request=raw_request,
)

async def score(
self, request: "ScoreRequest"
self, request: "ScoreRequest", raw_request: Optional[Request] = None
) -> AsyncGenerator[Union["ScoreResponse", "ErrorResponse"], None]:
"""Runs a score request to the engine and returns the response.

Returns an AsyncGenerator over the ScoreResponse object. This is so that the caller can have a consistent interface across all the methods of chat, completions, embeddings, and score.

Args:
request: A ScoreRequest object.
raw_request: The raw FastAPI request object.

Returns:
An AsyncGenerator over the ScoreResponse object.
"""
# NOTE: Score does not need batching, similar to embeddings.
return await self._run_request(
request, engine_method="score", batch_output_stream=False
request,
engine_method="score",
batch_output_stream=False,
raw_request=raw_request,
)

async def check_health(self) -> None:
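At the server layer, each public method now threads the FastAPI `Request` through to `_run_request`, which flattens it into the serializable `RawRequestInfo` before calling the engine. A minimal illustrative caller is sketched below; the FastAPI app, the route, and the way `server` is obtained are assumptions, and only the `chat(body, raw_request=...)` call shape comes from the diff:

```python
# Sketch of a caller passing the raw FastAPI request along with the parsed body.
# `server` is assumed to be an LLMServer-like object obtained elsewhere; in a real
# Ray Serve deployment the wiring goes through the ingress/router instead.
from fastapi import FastAPI, Request

from ray.llm._internal.serve.configs.openai_api_models import ChatCompletionRequest

app = FastAPI()
server = ...  # assumption: an LLMServer instance or handle created elsewhere


@app.post("/v1/chat/completions")
async def chat_completions(body: ChatCompletionRequest, raw_request: Request):
    # _run_request extracts headers and request.state into RawRequestInfo,
    # so the non-serializable Request object never reaches the engine.
    stream = await server.chat(body, raw_request=raw_request)
    return stream
```

Converting to a plain, dict-based `RawRequestInfo` at this boundary presumably keeps the engine interface free of framework objects that cannot be serialized across process or actor boundaries.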