From 66833e8c0c7aaeb9506f2ce7b3b649bea7ac786e Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Mon, 15 Apr 2024 12:08:51 +0300 Subject: [PATCH 1/6] #7 Made the stream chunk schema actual --- examples/lang/chat_stream_async.py | 2 +- src/glide/lang/router_async.py | 15 +++++++++++---- src/glide/lang/schemas.py | 10 ++++------ 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/examples/lang/chat_stream_async.py b/examples/lang/chat_stream_async.py index dd37c31..26fd68a 100644 --- a/examples/lang/chat_stream_async.py +++ b/examples/lang/chat_stream_async.py @@ -14,7 +14,7 @@ ) router_id: str = "default" # defined in Glide config (see glide.config.yaml) -question = "What are the kosher species?" +question = "What is the capital of Greenland?" async def chat_stream() -> None: diff --git a/src/glide/lang/router_async.py b/src/glide/lang/router_async.py index 488a7ef..e338570 100644 --- a/src/glide/lang/router_async.py +++ b/src/glide/lang/router_async.py @@ -16,6 +16,7 @@ from glide.exceptions import GlideUnavailable, GlideClientError, GlideClientMismatch from glide.lang import schemas from glide.lang.schemas import StreamChatRequest, StreamResponse, ChatRequestId +from glide.logging import logger from glide.typing import RouterId @@ -59,6 +60,7 @@ def request_chat(self, chat_request: StreamChatRequest) -> None: async def chat_stream( self, req: StreamChatRequest + # TODO: add timeout ) -> AsyncGenerator[StreamResponse, None]: chunk_buffer: asyncio.Queue[StreamResponse] = asyncio.Queue() self._response_streams[req.id] = chunk_buffer @@ -109,16 +111,21 @@ async def _receiver(self) -> None: json.loads(raw_chunk), ) + logger.debug("received stream chunk", extra={"chunk": chunk}) + if chunk_buffer := self._response_streams.get(chunk.id): chunk_buffer.put_nowait(chunk) continue self.response_chunks.put_nowait(chunk) - except pydantic.ValidationError as e: - raise GlideClientMismatch( + except pydantic.ValidationError: + logger.error( "Failed to validate Glide API response. " - "Please make sure Glide API and client versions are compatible" - ) from e + "Please make sure Glide API and client versions are compatible", + exc_info=True + ) + except Exception as e: + logger.exception(e) except asyncio.CancelledError: ... 
diff --git a/src/glide/lang/schemas.py b/src/glide/lang/schemas.py index 93cbbc6..b73d7c4 100644 --- a/src/glide/lang/schemas.py +++ b/src/glide/lang/schemas.py @@ -87,13 +87,11 @@ class ChatStreamChunk(Schema): """ id: ChatRequestId - # TODO: should be required, needs to fix on the Glide side - created: Optional[datetime] = None - provider: Optional[ProviderName] = None - router: Optional[RouterId] = None - model: Optional[ModelName] = None - + provider_id: ProviderName + router_id: RouterId model_id: str + + model_name: ModelName metadata: Optional[Metadata] = None model_response: ModelChunkResponse From 7fd59feb201ea638be9d906f58cc34a831b029c2 Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Mon, 15 Apr 2024 16:16:31 +0300 Subject: [PATCH 2/6] #7: Adjusted the client according to the latest update in Glide --- examples/lang/chat_stream_async.py | 48 +++++++++++++++++++----------- src/glide/lang/router_async.py | 32 ++++++++++---------- src/glide/lang/schemas.py | 25 ++++++++++------ 3 files changed, 63 insertions(+), 42 deletions(-) diff --git a/examples/lang/chat_stream_async.py b/examples/lang/chat_stream_async.py index 26fd68a..8e12997 100644 --- a/examples/lang/chat_stream_async.py +++ b/examples/lang/chat_stream_async.py @@ -6,15 +6,13 @@ from glide import AsyncGlideClient from glide.lang.schemas import ( - ChatStreamError, - StreamChatRequest, + ChatStreamRequest, ChatMessage, - StreamResponse, - ChatStreamChunk, + ChatStreamMessage, ) router_id: str = "default" # defined in Glide config (see glide.config.yaml) -question = "What is the capital of Greenland?" +question = "How are you?" async def chat_stream() -> None: @@ -23,31 +21,45 @@ async def chat_stream() -> None: print(f"💬Question: {question}") print("💬Answer: ", end="") - last_chunk: Optional[StreamResponse] = None - chat_req = StreamChatRequest(message=ChatMessage(role="user", content=question)) + last_msg: Optional[ChatStreamMessage] = None + chat_req = ChatStreamRequest(message=ChatMessage(role="user", content=question)) started_at = time.perf_counter() first_chunk_recv_at: Optional[float] = None async with glide_client.lang.stream_client(router_id) as client: - async for chunk in client.chat_stream(chat_req): + async for message in client.chat_stream(chat_req): if not first_chunk_recv_at: first_chunk_recv_at = time.perf_counter() - if isinstance(chunk, ChatStreamError): - print(f"💥err: {chunk.message} (code: {chunk.err_code})") + last_msg = message + + if err := message.error: + print(f"💥err: {err.message} (code: {err.err_code})") + continue + + if chunk := message.chunk: + print(chunk.model_response.message.content, end="", flush=True) continue - print(chunk.model_response.message.content, end="") - last_chunk = chunk + raise RuntimeError(f"unknown message type: {last_msg}") + + if last_msg: + if last_chunk := last_msg.chunk: + if reason := last_chunk.finish_reason: + print( + f"\n\n✅ Generation is done " + f"(provider: {last_chunk.provider_name}, model: {last_chunk.model_name}, reason: {reason.value})" + ) - if last_chunk: - if isinstance(last_chunk, ChatStreamChunk): - if reason := last_chunk.model_response.finish_reason: - print(f"\n✅ Generation is done (reason: {reason.value})") + print( + f"👀Glide Context (router_id: {last_msg.router_id}, model_id: {last_chunk.model_id})" + ) - if isinstance(last_chunk, ChatStreamError): - print(f"\n💥 Generation ended up with error (reason: {last_chunk.message})") + if err := last_msg.error: + print( + f"\n💥 Generation ended up with error (reason: {err.message}, code: 
{err.err_code})" + ) first_chunk_duration_ms: float = 0 diff --git a/src/glide/lang/router_async.py b/src/glide/lang/router_async.py index e338570..bd5a078 100644 --- a/src/glide/lang/router_async.py +++ b/src/glide/lang/router_async.py @@ -15,7 +15,7 @@ from glide.exceptions import GlideUnavailable, GlideClientError, GlideClientMismatch from glide.lang import schemas -from glide.lang.schemas import StreamChatRequest, StreamResponse, ChatRequestId +from glide.lang.schemas import ChatStreamRequest, ChatStreamMessage, ChatRequestId from glide.logging import logger from glide.typing import RouterId @@ -43,9 +43,11 @@ def __init__( self._handlers = handlers - self.requests: asyncio.Queue[StreamChatRequest] = asyncio.Queue() - self.response_chunks: asyncio.Queue[StreamResponse] = asyncio.Queue() - self._response_streams: Dict[ChatRequestId, asyncio.Queue[StreamResponse]] = {} + self.requests: asyncio.Queue[ChatStreamRequest] = asyncio.Queue() + self.response_chunks: asyncio.Queue[ChatStreamMessage] = asyncio.Queue() + self._response_streams: Dict[ + ChatRequestId, asyncio.Queue[ChatStreamMessage] + ] = {} self._sender_task: Optional[asyncio.Task] = None self._receiver_task: Optional[asyncio.Task] = None @@ -55,25 +57,26 @@ def __init__( self._ping_interval = ping_interval self._close_timeout = close_timeout - def request_chat(self, chat_request: StreamChatRequest) -> None: + def request_chat(self, chat_request: ChatStreamRequest) -> None: self.requests.put_nowait(chat_request) async def chat_stream( - self, req: StreamChatRequest + self, + req: ChatStreamRequest, # TODO: add timeout - ) -> AsyncGenerator[StreamResponse, None]: - chunk_buffer: asyncio.Queue[StreamResponse] = asyncio.Queue() - self._response_streams[req.id] = chunk_buffer + ) -> AsyncGenerator[ChatStreamMessage, None]: + msg_buffer: asyncio.Queue[ChatStreamMessage] = asyncio.Queue() + self._response_streams[req.id] = msg_buffer self.request_chat(req) while True: - chunk = await chunk_buffer.get() + chunk = await msg_buffer.get() yield chunk # TODO: handle stream end on error - if chunk.model_response.finish_reason: + if chunk.chunk and chunk.chunk.finish_reason: break self._response_streams.pop(req.id, None) @@ -106,9 +109,8 @@ async def _receiver(self) -> None: while self._ws_client and self._ws_client.open: try: raw_chunk = await self._ws_client.recv() - chunk: StreamResponse = pydantic.parse_obj_as( - StreamResponse, - json.loads(raw_chunk), + chunk: ChatStreamMessage = ChatStreamMessage( + **json.loads(raw_chunk) ) logger.debug("received stream chunk", extra={"chunk": chunk}) @@ -122,7 +124,7 @@ async def _receiver(self) -> None: logger.error( "Failed to validate Glide API response. " "Please make sure Glide API and client versions are compatible", - exc_info=True + exc_info=True, ) except Exception as e: logger.exception(e) diff --git a/src/glide/lang/schemas.py b/src/glide/lang/schemas.py index b73d7c4..0a2e680 100644 --- a/src/glide/lang/schemas.py +++ b/src/glide/lang/schemas.py @@ -3,7 +3,7 @@ import uuid from datetime import datetime from enum import Enum -from typing import List, Optional, Dict, Any, Union +from typing import List, Optional, Dict, Any from pydantic import Field @@ -18,7 +18,9 @@ class FinishReason(str, Enum): # generation is finished successfully without interruptions COMPLETE = "complete" # generation is interrupted because of the length of the response text - LENGTH = "length" + MAX_TOKENS = "max_tokens" + CONTENT_FILTERED = "content_filtered" + OTHER = "other" class LangRouter(Schema): ... 
@@ -67,7 +69,7 @@ class ChatResponse(Schema): model_response: ModelResponse -class StreamChatRequest(Schema): +class ChatStreamRequest(Schema): id: ChatRequestId = Field(default_factory=lambda: str(uuid.uuid4())) message: ChatMessage message_history: List[ChatMessage] = Field(default_factory=list) @@ -78,7 +80,6 @@ class StreamChatRequest(Schema): class ModelChunkResponse(Schema): metadata: Optional[Metadata] = None message: ChatMessage - finish_reason: Optional[FinishReason] = None class ChatStreamChunk(Schema): @@ -86,21 +87,27 @@ class ChatStreamChunk(Schema): A response chunk of a streaming chat """ - id: ChatRequestId - provider_id: ProviderName - router_id: RouterId model_id: str + provider_name: ProviderName model_name: ModelName - metadata: Optional[Metadata] = None + model_response: ModelChunkResponse + finish_reason: Optional[FinishReason] = None class ChatStreamError(Schema): id: ChatRequestId err_code: str message: str + + +class ChatStreamMessage(Schema): + id: ChatRequestId + created_at: datetime metadata: Optional[Metadata] = None + router_id: RouterId -StreamResponse = Union[ChatStreamChunk, ChatStreamError] + chunk: Optional[ChatStreamChunk] = None + error: Optional[ChatStreamError] = None From e5c36e0c8fcb01dda3b51583536dc7a638069c88 Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Mon, 15 Apr 2024 16:23:14 +0300 Subject: [PATCH 3/6] #7: Simplified access to the content chunk --- examples/lang/chat_stream_async.py | 4 ++-- src/glide/lang/schemas.py | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/lang/chat_stream_async.py b/examples/lang/chat_stream_async.py index 8e12997..5453a97 100644 --- a/examples/lang/chat_stream_async.py +++ b/examples/lang/chat_stream_async.py @@ -38,8 +38,8 @@ async def chat_stream() -> None: print(f"💥err: {err.message} (code: {err.err_code})") continue - if chunk := message.chunk: - print(chunk.model_response.message.content, end="", flush=True) + if content_chunk := message.content_chunk: + print(content_chunk, end="", flush=True) continue raise RuntimeError(f"unknown message type: {last_msg}") diff --git a/src/glide/lang/schemas.py b/src/glide/lang/schemas.py index 0a2e680..ebfb384 100644 --- a/src/glide/lang/schemas.py +++ b/src/glide/lang/schemas.py @@ -111,3 +111,10 @@ class ChatStreamMessage(Schema): chunk: Optional[ChatStreamChunk] = None error: Optional[ChatStreamError] = None + + @property + def content_chunk(self) -> Optional[str]: + if not self.chunk: + return None + + return self.chunk.model_response.message.content From 740e1dae081742c96cd5b3ab2e9b09de445444a6 Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Mon, 15 Apr 2024 17:07:34 +0300 Subject: [PATCH 4/6] #7 Raise error if stream was interrupted by exception --- examples/lang/chat_stream_async.py | 64 +++++++++++----------- src/glide/exceptions.py | 21 ++++++-- src/glide/lang/router_async.py | 86 +++++++++++++++++------------- src/glide/lang/schemas.py | 19 +++++++ 4 files changed, 118 insertions(+), 72 deletions(-) diff --git a/examples/lang/chat_stream_async.py b/examples/lang/chat_stream_async.py index 5453a97..c0488ae 100644 --- a/examples/lang/chat_stream_async.py +++ b/examples/lang/chat_stream_async.py @@ -27,52 +27,54 @@ async def chat_stream() -> None: started_at = time.perf_counter() first_chunk_recv_at: Optional[float] = None - async with glide_client.lang.stream_client(router_id) as client: - async for message in client.chat_stream(chat_req): - if not first_chunk_recv_at: - first_chunk_recv_at = time.perf_counter() + try: + 
async with glide_client.lang.stream_client(router_id) as client: + async for message in client.chat_stream(chat_req): + if not first_chunk_recv_at: + first_chunk_recv_at = time.perf_counter() - last_msg = message + last_msg = message - if err := message.error: - print(f"💥err: {err.message} (code: {err.err_code})") - continue + if err := message.error: + print(f"💥ERR: {err.message} (code: {err.err_code})") + print("🧹 Restarting the stream") + continue - if content_chunk := message.content_chunk: - print(content_chunk, end="", flush=True) - continue + if content_chunk := message.content_chunk: + print(content_chunk, end="", flush=True) + continue - raise RuntimeError(f"unknown message type: {last_msg}") + raise RuntimeError(f"Unknown message type: {last_msg}") + + if last_msg and last_msg.chunk and last_msg.finish_reason: + # LLM gen context + provider_name = last_msg.chunk.provider_name + model_name = last_msg.chunk.model_name + finish_reason = last_msg.finish_reason - if last_msg: - if last_chunk := last_msg.chunk: - if reason := last_chunk.finish_reason: print( f"\n\n✅ Generation is done " - f"(provider: {last_chunk.provider_name}, model: {last_chunk.model_name}, reason: {reason.value})" + f"(provider: {provider_name}, model: {model_name}, reason: {finish_reason.value})" ) print( - f"👀Glide Context (router_id: {last_msg.router_id}, model_id: {last_chunk.model_id})" + f"👀Glide Context (router_id: {last_msg.router_id}, model_id: {last_msg.chunk.model_id})" ) - if err := last_msg.error: - print( - f"\n💥 Generation ended up with error (reason: {err.message}, code: {err.err_code})" - ) - - first_chunk_duration_ms: float = 0 + first_chunk_duration_ms: float = 0 - if first_chunk_recv_at: - first_chunk_duration_ms = (first_chunk_recv_at - started_at) * 1_000 - print(f"\n⏱️First Response Chunk: {first_chunk_duration_ms:.2f}ms") + if first_chunk_recv_at: + first_chunk_duration_ms = (first_chunk_recv_at - started_at) * 1_000 + print(f"\n⏱️First Response Chunk: {first_chunk_duration_ms:.2f}ms") - chat_duration_ms = (time.perf_counter() - started_at) * 1_000 + chat_duration_ms = (time.perf_counter() - started_at) * 1_000 - print( - f"⏱️Chat Duration: {chat_duration_ms:.2f}ms " - f"({(chat_duration_ms - first_chunk_duration_ms):.2f}ms after the first chunk)" - ) + print( + f"⏱️Chat Duration: {chat_duration_ms:.2f}ms " + f"({(chat_duration_ms - first_chunk_duration_ms):.2f}ms after the first chunk)" + ) + except Exception as e: + print(f"💥Stream interrupted by ERR: {e}") if __name__ == "__main__": diff --git a/src/glide/exceptions.py b/src/glide/exceptions.py index bd09682..3c7d435 100644 --- a/src/glide/exceptions.py +++ b/src/glide/exceptions.py @@ -2,19 +2,34 @@ # SPDX-License-Identifier: APACHE-2.0 -class GlideUnavailable(Exception): +class GlideError(Exception): + """The base exception for all Glide server errors""" + + +class GlideUnavailable(GlideError): """ Occurs when Glide API is not available """ -class GlideClientError(Exception): +class GlideClientError(GlideError): """ Occurs when there is an issue with sending a Glide request """ -class GlideClientMismatch(Exception): +class GlideClientMismatch(GlideError): """ Occurs when there is a sign of possible compatibility issues between Glide API and the client version """ + + +class GlideChatStreamError(GlideError): + """ + Occurs when chat stream ends with an error + """ + + def __init__(self, message: str, err_code: str) -> None: + super().__init__(message) + + self.err_code = err_code diff --git a/src/glide/lang/router_async.py 
b/src/glide/lang/router_async.py index bd5a078..32dfa5b 100644 --- a/src/glide/lang/router_async.py +++ b/src/glide/lang/router_async.py @@ -13,7 +13,12 @@ from websockets import WebSocketClientProtocol -from glide.exceptions import GlideUnavailable, GlideClientError, GlideClientMismatch +from glide.exceptions import ( + GlideUnavailable, + GlideClientError, + GlideClientMismatch, + GlideChatStreamError, +) from glide.lang import schemas from glide.lang.schemas import ChatStreamRequest, ChatStreamMessage, ChatRequestId from glide.logging import logger @@ -70,16 +75,24 @@ async def chat_stream( self.request_chat(req) - while True: - chunk = await msg_buffer.get() + try: + while True: + message = await msg_buffer.get() + + if err := message.ended_with_err: + # fail only on fatal errors that indicates stream stop - yield chunk + raise GlideChatStreamError( + f"Chat stream {req.id} ended with an error: {err.message} (code: {err.err_code})", + err.err_code, + ) - # TODO: handle stream end on error - if chunk.chunk and chunk.chunk.finish_reason: - break + yield message # returns content chunk and some error messages - self._response_streams.pop(req.id, None) + if message.chunk and message.chunk.finish_reason: + break + finally: + self._response_streams.pop(req.id, None) async def start(self) -> None: self._ws_client = await websockets.connect( @@ -95,41 +108,38 @@ async def start(self) -> None: self._receiver_task = asyncio.create_task(self._receiver()) async def _sender(self) -> None: - try: - while self._ws_client and self._ws_client.open: + while self._ws_client and self._ws_client.open: + try: chat_request = await self.requests.get() await self._ws_client.send(chat_request.json()) - except asyncio.CancelledError: - # TODO: log - ... + except asyncio.CancelledError: + # TODO: log + ... async def _receiver(self) -> None: - try: - while self._ws_client and self._ws_client.open: - try: - raw_chunk = await self._ws_client.recv() - chunk: ChatStreamMessage = ChatStreamMessage( - **json.loads(raw_chunk) - ) - - logger.debug("received stream chunk", extra={"chunk": chunk}) - - if chunk_buffer := self._response_streams.get(chunk.id): - chunk_buffer.put_nowait(chunk) - continue - - self.response_chunks.put_nowait(chunk) - except pydantic.ValidationError: - logger.error( - "Failed to validate Glide API response. " - "Please make sure Glide API and client versions are compatible", - exc_info=True, - ) - except Exception as e: - logger.exception(e) - except asyncio.CancelledError: - ... + while self._ws_client and self._ws_client.open: + try: + raw_chunk = await self._ws_client.recv() + chunk: ChatStreamMessage = ChatStreamMessage(**json.loads(raw_chunk)) + + logger.debug("received stream chunk", extra={"chunk": chunk}) + + if chunk_buffer := self._response_streams.get(chunk.id): + chunk_buffer.put_nowait(chunk) + continue + + self.response_chunks.put_nowait(chunk) + except pydantic.ValidationError: + logger.error( + "Failed to validate Glide API response. " + "Please make sure Glide API and client versions are compatible", + exc_info=True, + ) + except asyncio.CancelledError: + ... 
+ except Exception as e: + logger.exception(e) async def stop(self) -> None: if self._sender_task: diff --git a/src/glide/lang/schemas.py b/src/glide/lang/schemas.py index ebfb384..03f2e6f 100644 --- a/src/glide/lang/schemas.py +++ b/src/glide/lang/schemas.py @@ -20,6 +20,7 @@ class FinishReason(str, Enum): # generation is interrupted because of the length of the response text MAX_TOKENS = "max_tokens" CONTENT_FILTERED = "content_filtered" + ERROR = "error" OTHER = "other" @@ -100,6 +101,7 @@ class ChatStreamError(Schema): id: ChatRequestId err_code: str message: str + finish_reason: Optional[FinishReason] = None class ChatStreamMessage(Schema): @@ -112,6 +114,23 @@ class ChatStreamMessage(Schema): chunk: Optional[ChatStreamChunk] = None error: Optional[ChatStreamError] = None + @property + def finish_reason(self) -> Optional[FinishReason]: + if self.chunk and self.chunk.finish_reason: + return self.chunk.finish_reason + + if self.error and self.error.finish_reason: + return self.error.finish_reason + + return None + + @property + def ended_with_err(self) -> Optional[ChatStreamError]: + if self.error and self.error.finish_reason: + return self.error + + return None + @property def content_chunk(self) -> Optional[str]: if not self.chunk: From bcf167652164fcc475c787a7101fdaef0bfa6b6f Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Mon, 15 Apr 2024 21:38:07 +0300 Subject: [PATCH 5/6] #7 Fixed issue with empty content chunks and hanging of the client tear down code --- examples/lang/chat_stream_async.py | 62 +++++++++++++++--------------- src/glide/lang/router_async.py | 19 +++++---- src/glide/lang/schemas.py | 7 ++++ 3 files changed, 46 insertions(+), 42 deletions(-) diff --git a/examples/lang/chat_stream_async.py b/examples/lang/chat_stream_async.py index c0488ae..e8c1c15 100644 --- a/examples/lang/chat_stream_async.py +++ b/examples/lang/chat_stream_async.py @@ -12,7 +12,7 @@ ) router_id: str = "default" # defined in Glide config (see glide.config.yaml) -question = "How are you?" +question = "What is the most complicated theory discovered by humanity?" 
async def chat_stream() -> None: @@ -27,54 +27,52 @@ async def chat_stream() -> None: started_at = time.perf_counter() first_chunk_recv_at: Optional[float] = None - try: - async with glide_client.lang.stream_client(router_id) as client: + async with glide_client.lang.stream_client(router_id) as client: + try: async for message in client.chat_stream(chat_req): if not first_chunk_recv_at: first_chunk_recv_at = time.perf_counter() last_msg = message + if message.chunk: + print(message.content_chunk, end="", flush=True) + continue + if err := message.error: print(f"💥ERR: {err.message} (code: {err.err_code})") print("🧹 Restarting the stream") continue - if content_chunk := message.content_chunk: - print(content_chunk, end="", flush=True) - continue - - raise RuntimeError(f"Unknown message type: {last_msg}") - - if last_msg and last_msg.chunk and last_msg.finish_reason: - # LLM gen context - provider_name = last_msg.chunk.provider_name - model_name = last_msg.chunk.model_name - finish_reason = last_msg.finish_reason + print(f"😮Unknown message type: {message}") + except Exception as e: + print(f"💥Stream interrupted by ERR: {e}") - print( - f"\n\n✅ Generation is done " - f"(provider: {provider_name}, model: {model_name}, reason: {finish_reason.value})" - ) + if last_msg and last_msg.chunk and last_msg.finish_reason: + # LLM gen context + provider_name = last_msg.chunk.provider_name + model_name = last_msg.chunk.model_name + finish_reason = last_msg.finish_reason - print( - f"👀Glide Context (router_id: {last_msg.router_id}, model_id: {last_msg.chunk.model_id})" - ) + print( + f"\n\n✅ Generation is done " + f"(provider: {provider_name}, model: {model_name}, reason: {finish_reason.value})" + ) - first_chunk_duration_ms: float = 0 + print( + f"👀Glide Context (router_id: {last_msg.router_id}, model_id: {last_msg.chunk.model_id})" + ) - if first_chunk_recv_at: - first_chunk_duration_ms = (first_chunk_recv_at - started_at) * 1_000 - print(f"\n⏱️First Response Chunk: {first_chunk_duration_ms:.2f}ms") + if first_chunk_recv_at: + first_chunk_duration_ms = (first_chunk_recv_at - started_at) * 1_000 + print(f"\n⏱️First Response Chunk: {first_chunk_duration_ms:.2f}ms") - chat_duration_ms = (time.perf_counter() - started_at) * 1_000 + chat_duration_ms = (time.perf_counter() - started_at) * 1_000 - print( - f"⏱️Chat Duration: {chat_duration_ms:.2f}ms " - f"({(chat_duration_ms - first_chunk_duration_ms):.2f}ms after the first chunk)" - ) - except Exception as e: - print(f"💥Stream interrupted by ERR: {e}") + print( + f"⏱️Chat Duration: {chat_duration_ms:.2f}ms " + f"({(chat_duration_ms - first_chunk_duration_ms):.2f}ms after the first chunk)" + ) if __name__ == "__main__": diff --git a/src/glide/lang/router_async.py b/src/glide/lang/router_async.py index 32dfa5b..ab95d3c 100644 --- a/src/glide/lang/router_async.py +++ b/src/glide/lang/router_async.py @@ -80,8 +80,7 @@ async def chat_stream( message = await msg_buffer.get() if err := message.ended_with_err: - # fail only on fatal errors that indicates stream stop - + # fail only on fatal errors that indicate stream stop raise GlideChatStreamError( f"Chat stream {req.id} ended with an error: {err.message} (code: {err.err_code})", err.err_code, @@ -89,7 +88,7 @@ async def chat_stream( yield message # returns content chunk and some error messages - if message.chunk and message.chunk.finish_reason: + if message.finish_reason: break finally: self._response_streams.pop(req.id, None) @@ -115,21 +114,21 @@ async def _sender(self) -> None: await 
self._ws_client.send(chat_request.json()) except asyncio.CancelledError: # TODO: log - ... + break async def _receiver(self) -> None: while self._ws_client and self._ws_client.open: try: raw_chunk = await self._ws_client.recv() - chunk: ChatStreamMessage = ChatStreamMessage(**json.loads(raw_chunk)) + message: ChatStreamMessage = ChatStreamMessage(**json.loads(raw_chunk)) - logger.debug("received stream chunk", extra={"chunk": chunk}) + logger.debug("received chat stream message", extra={"message": message}) - if chunk_buffer := self._response_streams.get(chunk.id): - chunk_buffer.put_nowait(chunk) + if msg_buffer := self._response_streams.get(message.id): + msg_buffer.put_nowait(message) continue - self.response_chunks.put_nowait(chunk) + self.response_chunks.put_nowait(message) except pydantic.ValidationError: logger.error( "Failed to validate Glide API response. " @@ -137,7 +136,7 @@ async def _receiver(self) -> None: exc_info=True, ) except asyncio.CancelledError: - ... + break except Exception as e: logger.exception(e) diff --git a/src/glide/lang/schemas.py b/src/glide/lang/schemas.py index 03f2e6f..e32a632 100644 --- a/src/glide/lang/schemas.py +++ b/src/glide/lang/schemas.py @@ -133,6 +133,13 @@ def ended_with_err(self) -> Optional[ChatStreamError]: @property def content_chunk(self) -> Optional[str]: + """ + Returns received text generation chunk. + + Be careful with using this method to see if there is a chunk (rather than an error), + because content can be an empty string with some providers like OpenAI. + Better check for `self.chunk` in that case. + """ if not self.chunk: return None From 2ccf16df362c94c85956508c798d53cff7d0a4d8 Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Mon, 15 Apr 2024 21:39:13 +0300 Subject: [PATCH 6/6] #7 Leave todos --- src/glide/lang/router_async.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/glide/lang/router_async.py b/src/glide/lang/router_async.py index ab95d3c..1d1b15b 100644 --- a/src/glide/lang/router_async.py +++ b/src/glide/lang/router_async.py @@ -141,6 +141,7 @@ async def _receiver(self) -> None: logger.exception(e) async def stop(self) -> None: + # TODO: allow to timeout shutdown too if self._sender_task: self._sender_task.cancel() await self._sender_task
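
Below is a minimal consumer sketch of the streaming API as it stands at the end of this series (PATCH 6/6). It is illustrative only and not part of the patches: the function takes an already-constructed AsyncGlideClient because the client's constructor arguments are not shown in these hunks, and it assumes a router named "default" defined in glide.config.yaml. Everything else it touches (ChatStreamRequest, ChatStreamMessage, content_chunk, finish_reason, GlideChatStreamError) is introduced by the patches above.

    # Minimal consumer sketch (illustrative, not part of the patch series).
    # Assumes the AsyncGlideClient has been constructed elsewhere and that a
    # "default" router exists in glide.config.yaml.
    from typing import Optional

    from glide import AsyncGlideClient
    from glide.exceptions import GlideChatStreamError
    from glide.lang.schemas import ChatMessage, ChatStreamMessage, ChatStreamRequest


    async def run_stream(
        glide_client: AsyncGlideClient, router_id: str, question: str
    ) -> None:
        req = ChatStreamRequest(message=ChatMessage(role="user", content=question))
        last_msg: Optional[ChatStreamMessage] = None

        async with glide_client.lang.stream_client(router_id) as client:
            try:
                async for message in client.chat_stream(req):
                    last_msg = message

                    if message.chunk:
                        # Check `chunk`, not `content_chunk`: some providers
                        # (e.g. OpenAI) send chunks whose content is "".
                        print(message.content_chunk, end="", flush=True)
                    elif message.error:
                        # Errors without a finish_reason are recoverable and
                        # are yielded into the stream instead of raising.
                        print(f"\nrecoverable error: {message.error.message}")
            except GlideChatStreamError as e:
                # Raised when an error message carries a finish_reason,
                # i.e. the stream is over and no more chunks will arrive.
                print(f"\nstream failed: {e} (code: {e.err_code})")
                return

        if last_msg and last_msg.finish_reason:
            print(f"\nfinished: {last_msg.finish_reason.value}")

The split between the two error paths mirrors the change in PATCH 4/6: an error message without a finish_reason is yielded so the caller can keep reading, while an error that carries a finish_reason ends the stream and surfaces as GlideChatStreamError.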