diff --git a/.github/workflows/nix_tests.yaml b/.github/workflows/nix_tests.yaml index 06768a7becb..f2209f8a453 100644 --- a/.github/workflows/nix_tests.yaml +++ b/.github/workflows/nix_tests.yaml @@ -38,4 +38,4 @@ jobs: env: HF_TOKEN: ${{ secrets.HF_TOKEN }} - name: Rust tests. - run: nix build .#checks.$(nix eval --impure --raw --expr 'builtins.currentSystem').rust -L + run: nix develop .#test --command cargo test diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 5ad0fd6a28d..5f00180c537 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -42,6 +42,7 @@ jobs: sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET - name: Install run: | + sudo apt update sudo apt install python3.11-dev -y make install-cpu - name: Run server tests diff --git a/.redocly.lint-ignore.yaml b/.redocly.lint-ignore.yaml index 13b80497ea0..fb02c00fa22 100644 --- a/.redocly.lint-ignore.yaml +++ b/.redocly.lint-ignore.yaml @@ -23,9 +23,11 @@ docs/openapi.json: - '#/components/schemas/GenerateResponse/properties/details/nullable' - '#/components/schemas/StreamResponse/properties/details/nullable' - '#/components/schemas/ChatRequest/properties/response_format/nullable' + - '#/components/schemas/ChatRequest/properties/stream_options/nullable' - '#/components/schemas/ChatRequest/properties/tool_choice/nullable' - '#/components/schemas/ToolChoice/nullable' - '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable' + - '#/components/schemas/ChatCompletionChunk/properties/usage/nullable' - '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable' no-invalid-media-type-examples: - '#/paths/~1/post/responses/422/content/application~1json/example' diff --git a/Cargo.toml b/Cargo.toml index a50bba24e5f..ffd45f16a1f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,8 @@ members = [ "backends/grpc-metadata", "backends/trtllm", "backends/client", - "launcher" + "launcher", + "router" ] default-members = [ "benchmark", @@ -13,7 +14,8 @@ default-members = [ "backends/grpc-metadata", # "backends/trtllm", "backends/client", - "launcher" + "launcher", + "router" ] resolver = "2" diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index e36dd470230..f7f823fc5da 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -168,7 +168,7 @@ class ChatCompletionComplete(BaseModel): # Log probabilities for the chat completion logprobs: Optional[Any] # Reason for completion - finish_reason: str + finish_reason: Optional[str] # Usage details of the chat completion usage: Optional[Any] = None @@ -191,6 +191,7 @@ class ChatCompletionChunk(BaseModel): model: str system_fingerprint: str choices: List[Choice] + usage: Optional[Any] = None class Parameters(BaseModel): diff --git a/docs/openapi.json b/docs/openapi.json index 691705f28ba..f8de656472b 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -742,6 +742,14 @@ }, "system_fingerprint": { "type": "string" + }, + "usage": { + "allOf": [ + { + "$ref": "#/components/schemas/Usage" + } + ], + "nullable": true } } }, @@ -937,6 +945,14 @@ "stream": { "type": "boolean" }, + "stream_options": { + "allOf": [ + { + "$ref": "#/components/schemas/StreamOptions" + } + ], + "nullable": true + }, "temperature": { "type": "number", "format": "float", @@ -1912,6 +1928,19 @@ } } }, + "StreamOptions": { + "type": "object", + "required": [ + "include_usage" + ], + "properties": { + "include_usage": { + "type": "boolean", + "description": "If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.", + "example": "true" + } + } + }, "StreamResponse": { "type": "object", "required": [ diff --git a/flake.lock b/flake.lock index a61907890d8..8d0f4070c3c 100644 --- a/flake.lock +++ b/flake.lock @@ -479,11 +479,11 @@ "systems": "systems_6" }, "locked": { - "lastModified": 1710146030, - "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "lastModified": 1726560853, + "narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=", "owner": "numtide", "repo": "flake-utils", - "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a", "type": "github" }, "original": { @@ -853,11 +853,11 @@ ] }, "locked": { - "lastModified": 1726280639, - "narHash": "sha256-YfLRPlFZWrT2oRLNAoqf7G3+NnUTDdlIJk6tmBU7kXM=", + "lastModified": 1726626348, + "narHash": "sha256-sYV7e1B1yLcxo8/h+/hTwzZYmaju2oObNiy5iRI0C30=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "e9f8641c92f26fd1e076e705edb12147c384171d", + "rev": "6fd52ad8bd88f39efb2c999cc971921c2fb9f3a2", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 3d349ff23ba..07348e7448f 100644 --- a/flake.nix +++ b/flake.nix @@ -67,31 +67,38 @@ ''; }; server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; }; + client = pkgs.python3.pkgs.callPackage ./nix/client.nix { }; in { checks = { - rust = with pkgs; rustPlatform.buildRustPackage { - name = "rust-checks"; - src = ./.; - cargoLock = { - lockFile = ./Cargo.lock; + rust = + with pkgs; + rustPlatform.buildRustPackage { + name = "rust-checks"; + src = ./.; + cargoLock = { + lockFile = ./Cargo.lock; + }; + buildInputs = [ openssl.dev ]; + nativeBuildInputs = [ + clippy + pkg-config + protobuf + python3 + rustfmt + ]; + buildPhase = '' + cargo check + ''; + checkPhase = '' + cargo fmt -- --check + cargo test -j $NIX_BUILD_CORES + cargo clippy + ''; + installPhase = "touch $out"; }; - buildInputs = [ openssl.dev ]; - nativeBuildInputs = [ clippy pkg-config protobuf python3 rustfmt ]; - buildPhase = '' - cargo check - ''; - checkPhase = '' - cargo fmt -- --check - cargo test -j $NIX_BUILD_CORES - cargo clippy - ''; - installPhase = "touch $out"; - } ; }; - formatter = pkgs.nixfmt-rfc-style; - devShells = with pkgs; rec { default = pure; @@ -106,10 +113,11 @@ test = mkShell { buildInputs = [ - # benchmark - # launcher - # router + benchmark + launcher + router server + client openssl.dev pkg-config cargo diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_stream_usage.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_stream_usage.json new file mode 100644 index 00000000000..8c7be4cb1ec --- /dev/null +++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_stream_usage.json @@ -0,0 +1,206 @@ +[ + { + "choices": [ + { + "delta": { + "content": "**", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656043, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": "Deep", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656043, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": " Learning", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656043, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": ":", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656043, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": " An", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656043, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": " Overview", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656043, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": "**\n", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656044, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": "================================", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656044, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": "=====", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": null, + "index": 0, + "logprobs": null + } + ], + "created": 1726656044, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": null + }, + { + "choices": [ + { + "delta": { + "content": "\n\n", + "role": "assistant", + "tool_calls": null + }, + "finish_reason": "length", + "index": 0, + "logprobs": null + } + ], + "created": 1726656044, + "id": "", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "object": "chat.completion.chunk", + "system_fingerprint": "2.2.1-dev0-native", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 40, + "total_tokens": 50 + } + } +] diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py index a3b6651d88f..6c359f1e9ad 100644 --- a/integration-tests/models/test_completion_prompts.py +++ b/integration-tests/models/test_completion_prompts.py @@ -3,9 +3,7 @@ import json from aiohttp import ClientSession -from text_generation.types import ( - Completion, -) +from text_generation.types import Completion, ChatCompletionChunk @pytest.fixture(scope="module") @@ -50,6 +48,114 @@ def test_flash_llama_completion_single_prompt( assert response == response_snapshot +@pytest.mark.release +async def test_flash_llama_completion_stream_usage( + flash_llama_completion, response_snapshot +): + url = f"{flash_llama_completion.base_url}/v1/chat/completions" + request = { + "model": "tgi", + "messages": [ + { + "role": "user", + "content": "What is Deep Learning?", + } + ], + "max_tokens": 10, + "temperature": 0.0, + "stream_options": {"include_usage": True}, + "stream": True, + } + string = "" + chunks = [] + had_usage = False + async with ClientSession(headers=flash_llama_completion.headers) as session: + async with session.post(url, json=request) as response: + # iterate over the stream + async for chunk in response.content.iter_any(): + # remove "data:" + chunk = chunk.decode().split("\n\n") + # remove "data:" if present + chunk = [c.replace("data:", "") for c in chunk] + # remove empty strings + chunk = [c for c in chunk if c] + # remove completion marking chunk + chunk = [c for c in chunk if c != " [DONE]"] + # parse json + chunk = [json.loads(c) for c in chunk] + + for c in chunk: + chunks.append(ChatCompletionChunk(**c)) + assert "choices" in c + if len(c["choices"]) == 1: + index = c["choices"][0]["index"] + assert index == 0 + string += c["choices"][0]["delta"]["content"] + + has_usage = c["usage"] is not None + assert not had_usage + if has_usage: + had_usage = True + else: + raise RuntimeError("Expected different payload") + assert had_usage + assert ( + string + == "**Deep Learning: An Overview**\n=====================================\n\n" + ) + assert chunks == response_snapshot + + request = { + "model": "tgi", + "messages": [ + { + "role": "user", + "content": "What is Deep Learning?", + } + ], + "max_tokens": 10, + "temperature": 0.0, + "stream": True, + } + string = "" + chunks = [] + had_usage = False + async with ClientSession(headers=flash_llama_completion.headers) as session: + async with session.post(url, json=request) as response: + # iterate over the stream + async for chunk in response.content.iter_any(): + # remove "data:" + chunk = chunk.decode().split("\n\n") + # remove "data:" if present + chunk = [c.replace("data:", "") for c in chunk] + # remove empty strings + chunk = [c for c in chunk if c] + # remove completion marking chunk + chunk = [c for c in chunk if c != " [DONE]"] + # parse json + chunk = [json.loads(c) for c in chunk] + + for c in chunk: + chunks.append(ChatCompletionChunk(**c)) + assert "choices" in c + if len(c["choices"]) == 1: + index = c["choices"][0]["index"] + assert index == 0 + string += c["choices"][0]["delta"]["content"] + + has_usage = c["usage"] is not None + assert not had_usage + if has_usage: + had_usage = True + else: + raise RuntimeError("Expected different payload") + assert not had_usage + assert ( + string + == "**Deep Learning: An Overview**\n=====================================\n\n" + ) + + @pytest.mark.release def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot): response = requests.post( diff --git a/nix/client.nix b/nix/client.nix new file mode 100644 index 00000000000..351fd08abb2 --- /dev/null +++ b/nix/client.nix @@ -0,0 +1,21 @@ +{ + buildPythonPackage, + poetry-core, + huggingface-hub, + pydantic, +}: + +buildPythonPackage { + name = "text-generation"; + + src = ../clients/python; + + pyproject = true; + + build-system = [ poetry-core ]; + + dependencies = [ + huggingface-hub + pydantic + ]; +} diff --git a/router/src/lib.rs b/router/src/lib.rs index d8029c724a2..ad8924df0c5 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -684,6 +684,7 @@ pub(crate) struct ChatCompletionChunk { pub model: String, pub system_fingerprint: String, pub choices: Vec, + pub usage: Option, } #[derive(Clone, Serialize, ToSchema)] @@ -732,6 +733,7 @@ impl ChatCompletionChunk { created: u64, logprobs: Option, finish_reason: Option, + usage: Option, ) -> Self { let delta = match (delta, tool_calls) { (Some(delta), _) => ChatCompletionDelta::Chat(TextMessage { @@ -766,6 +768,7 @@ impl ChatCompletionChunk { logprobs, finish_reason, }], + usage, } } } @@ -880,6 +883,18 @@ pub(crate) struct ChatRequest { #[serde(default)] #[schema(nullable = true, default = "null", example = "null")] pub guideline: Option, + + /// Options for streaming response. Only set this when you set stream: true. + #[serde(default)] + #[schema(nullable = true, example = "null")] + pub stream_options: Option, +} + +#[derive(Clone, Deserialize, ToSchema, Serialize)] +struct StreamOptions { + /// If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value. + #[schema(example = "true")] + include_usage: bool, } pub fn default_tool_prompt() -> String { @@ -1472,6 +1487,27 @@ mod tests { let textmsg: TextMessage = message.into(); assert_eq!(textmsg.content, "Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)"); } + + #[test] + fn test_chat_stream_options() { + let json = json!({ + "model": "", + "stream_options": {"include_usage": true}, + "messages": [{ + "role": "user", + "content": "Hello" + }] + }); + let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap(); + + assert!(matches!( + request.stream_options, + Some(StreamOptions { + include_usage: true + }) + )); + } + #[test] fn openai_output() { let message = OutputMessage::ChatMessage(TextMessage { diff --git a/router/src/server.rs b/router/src/server.rs index 9cec2aaad7d..32c86e0f392 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -13,8 +13,8 @@ use crate::{ usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, - OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamResponse, TextMessage, Token, - TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation, + OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse, + TextMessage, Token, TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, @@ -1175,6 +1175,7 @@ async fn chat_completions( seed, stop, stream, + stream_options, tools, tool_choice, tool_prompt, @@ -1265,6 +1266,28 @@ async fn chat_completions( (content, None) }; + let (usage, finish_reason) = match stream_token.details { + Some(details) => { + let usage = if stream_options + .as_ref() + .map(|s| s.include_usage) + .unwrap_or(false) + { + let completion_tokens = details.generated_tokens; + let prompt_tokens = details.input_length; + let total_tokens = prompt_tokens + completion_tokens; + Some(Usage { + completion_tokens, + prompt_tokens, + total_tokens, + }) + } else { + None + }; + (usage, Some(details.finish_reason.format(true))) + } + None => (None, None), + }; event .json_data(CompletionType::ChatCompletionChunk( ChatCompletionChunk::new( @@ -1274,7 +1297,8 @@ async fn chat_completions( tool_calls, current_time, logprobs, - stream_token.details.map(|d| d.finish_reason.format(true)), + finish_reason, + usage, ), )) .unwrap_or_else(|e| { @@ -1664,6 +1688,7 @@ StreamDetails, ErrorResponse, GrammarType, Usage, +StreamOptions, DeltaToolCall, ToolType, Tool,