From fe30be6fbae807c35f0bab8c0d164d8272f992fc Mon Sep 17 00:00:00 2001 From: Ankush Gola <9536492+agola11@users.noreply.github.com> Date: Wed, 1 Mar 2023 21:55:43 -0800 Subject: [PATCH] add async and streaming support to `OpenAIChat` (#1378) title says it all --- docs/modules/llms/async_llm.ipynb | 65 +++--- docs/modules/llms/streaming_llm.ipynb | 113 +++++++++-- langchain/llms/base.py | 10 +- langchain/llms/openai.py | 213 +++++++++++--------- tests/integration_tests/llms/test_openai.py | 56 ++++- 5 files changed, 319 insertions(+), 138 deletions(-) diff --git a/docs/modules/llms/async_llm.ipynb b/docs/modules/llms/async_llm.ipynb index 730d010210359..ac31af8deca4f 100644 --- a/docs/modules/llms/async_llm.ipynb +++ b/docs/modules/llms/async_llm.ipynb @@ -1,7 +1,6 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "id": "f6574496-b360-4ffa-9523-7fd34a590164", "metadata": {}, @@ -10,14 +9,14 @@ "\n", "LangChain provides async support for LLMs by leveraging the [asyncio](https://docs.python.org/3/library/asyncio.html) library.\n", "\n", - "Async support is particularly useful for calling multiple LLMs concurrently, as these calls are network-bound. Currently, only `OpenAI` and `PromptLayerOpenAI` is supported, but async support for other LLMs is on the roadmap.\n", + "Async support is particularly useful for calling multiple LLMs concurrently, as these calls are network-bound. Currently, only `OpenAI` `OpenAIChat`, and `PromptLayerOpenAI` are supported, but async support for other LLMs is on the roadmap.\n", "\n", "You can use the `agenerate` method to call an OpenAI LLM asynchronously." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 1, "id": "5e49e96c-0f88-466d-b3d3-ea0966bdf19e", "metadata": { "tags": [] @@ -29,64 +28,66 @@ "text": [ "\n", "\n", - "I'm doing well. How about you?\n", + "As an AI language model, I don't have feelings like humans, but I'm functioning properly. How may I assist you?\n", "\n", "\n", - "I'm doing well, thank you. How about you?\n", + "I'm an AI language model, so I don't have emotions, but I'm functioning properly. How may I assist you today?\n", "\n", "\n", - "I'm doing well, thank you. How about you?\n", + "As an AI language model, I do not have emotions like humans, but I'm functioning normally. How can I assist you today?\n", "\n", "\n", - "I'm doing well, thank you. How about you?\n", + "I am an AI language model, so I do not have feelings, but I am here to assist you. How may I help you today?\n", "\n", - "I am doing quite well. How about you?\n", "\n", + "As an AI language model, I do not have feelings or emotions but I'm always ready to assist you. How may I assist you today?\n", "\n", - "I'm doing well, thank you. How about you?\n", "\n", + "As an AI language model, I don't have feelings, but I'm functioning normally. How may I assist you today?\n", "\n", - "I'm doing great, thank you! How about you?\n", "\n", + "As an AI language model, I don't have feelings, but I'm functioning properly. Thank you. How may I assist you today?\n", "\n", - "I'm doing well, thanks for asking. How about you?\n", "\n", + "As an AI language model, I don't have emotions, so I don't have a specific feeling or emotion. How can I assist you today?\n", "\n", - "I'm doing well, thank you. How about you?\n", "\n", + "As an AI language model, I do not have feelings or emotions. However, I am functioning as intended and ready to assist you with any queries you may have. How can I be of assistance today?\n", "\n", - "I'm doing well, thank you. 
How about you?\n", - "\u001b[1mConcurrent executed in 1.93 seconds.\u001b[0m\n", "\n", + "As an AI language model, I do not have feelings, but I am functioning well. Thank you for asking. How can I assist you today?\n", + "\u001b[1mConcurrent executed in 0.92 seconds.\u001b[0m\n", "\n", - "I'm doing well, thank you. How about you?\n", "\n", + "As an AI language model, I don't have feelings, but I'm functioning well. How can I assist you today?\n", "\n", - "I'm doing well, thank you. How about you?\n", "\n", + "As an AI language model, I don't have feelings, but I'm functioning well. Thank you for asking. How may I assist you today?\n", "\n", - "I'm doing well, thank you. How about you?\n", "\n", + "I'm an AI language model, so I don't have feelings, but I'm functioning well. How can I assist you today?\n", "\n", - "I'm doing well, thank you. How about you?\n", "\n", + "As an AI language model, I don't have feelings, but I'm functioning well. Thank you for asking. How may I assist you today?\n", "\n", - "I'm doing well, thank you. How about you?\n", "\n", + "As an AI language model, I don't have feelings, but I am functioning well. How can I assist you today?\n", "\n", - "I'm doing well, thank you. How about you?\n", "\n", - "I'm doing well, thank you. How about you?\n", + "As an AI language model, I don't have feelings but I'm functioning well. How can I assist you today?\n", "\n", "\n", - "I'm doing well, thank you. How about you?\n", + "As an AI language model, I do not have personal emotions. However, I am functioning well and ready to assist you with any queries or tasks you have. How may I assist you today?\n", "\n", "\n", - "I'm doing well, thank you. How about you?\n", + "As an AI language model, I do not have feelings or emotions, but I'm functioning well. How can I assist you today?\n", "\n", "\n", - "I'm doing great, thank you. How about you?\n", - "\u001b[1mSerial executed in 10.54 seconds.\u001b[0m\n" + "I am an AI language model and do not have feelings. But I am functioning properly and ready to assist you with any task. How may I help you today?\n", + "\n", + "\n", + "As an AI language model, I do not have emotions, but I am functioning well. 
How can I assist you today?\n", + "\u001b[1mSerial executed in 5.00 seconds.\u001b[0m\n" ] } ], @@ -94,10 +95,10 @@ "import time\n", "import asyncio\n", "\n", - "from langchain.llms import OpenAI\n", + "from langchain.llms import OpenAIChat\n", "\n", "def generate_serially():\n", - " llm = OpenAI(temperature=0.9)\n", + " llm = OpenAIChat(temperature=0.9)\n", " for _ in range(10):\n", " resp = llm.generate([\"Hello, how are you?\"])\n", " print(resp.generations[0][0].text)\n", @@ -109,7 +110,7 @@ "\n", "\n", "async def generate_concurrently():\n", - " llm = OpenAI(temperature=0.9)\n", + " llm = OpenAIChat(temperature=0.9)\n", " tasks = [async_generate(llm) for _ in range(10)]\n", " await asyncio.gather(*tasks)\n", "\n", @@ -125,6 +126,14 @@ "elapsed = time.perf_counter() - s\n", "print('\\033[1m' + f\"Serial executed in {elapsed:0.2f} seconds.\" + '\\033[0m')" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1d3a966-3a27-44e8-9441-ed72f01b86f4", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/docs/modules/llms/streaming_llm.ipynb b/docs/modules/llms/streaming_llm.ipynb index f1b5f2c0732c2..560a6a2482c0f 100644 --- a/docs/modules/llms/streaming_llm.ipynb +++ b/docs/modules/llms/streaming_llm.ipynb @@ -7,12 +7,12 @@ "source": [ "# Streaming with LLMs\n", "\n", - "LangChain provides streaming support for LLMs. Currently, we only support streaming for the `OpenAI` LLM implementation, but streaming support for other LLM implementations is on the roadmap. To utilize streaming, use a [`CallbackHandler`](https://github.com/hwchase17/langchain/blob/master/langchain/callbacks/base.py) that implements `on_llm_new_token`. In this example, we are using [`StreamingStdOutCallbackHandler`]()." + "LangChain provides streaming support for LLMs. Currently, we only support streaming for the `OpenAI` and `OpenAIChat` LLM implementation, but streaming support for other LLM implementations is on the roadmap. To utilize streaming, use a [`CallbackHandler`](https://github.com/hwchase17/langchain/blob/master/langchain/callbacks/base.py) that implements `on_llm_new_token`. In this example, we are using [`StreamingStdOutCallbackHandler`]()." 
] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 2, "id": "4ac0ff54-540a-4f2b-8d9a-b590fec7fe07", "metadata": { "tags": [] @@ -27,43 +27,43 @@ "Verse 1\n", "I'm sippin' on sparkling water,\n", "It's so refreshing and light,\n", - "It's the perfect way to quench my thirst,\n", + "It's the perfect way to quench my thirst\n", "On a hot summer night.\n", "\n", "Chorus\n", "Sparkling water, sparkling water,\n", "It's the best way to stay hydrated,\n", - "It's so refreshing and light,\n", - "It's the perfect way to stay alive.\n", + "It's so crisp and so clean,\n", + "It's the perfect way to stay refreshed.\n", "\n", "Verse 2\n", "I'm sippin' on sparkling water,\n", "It's so bubbly and bright,\n", - "It's the perfect way to cool me down,\n", + "It's the perfect way to cool me down\n", "On a hot summer night.\n", "\n", "Chorus\n", "Sparkling water, sparkling water,\n", "It's the best way to stay hydrated,\n", - "It's so refreshing and light,\n", - "It's the perfect way to stay alive.\n", + "It's so crisp and so clean,\n", + "It's the perfect way to stay refreshed.\n", "\n", "Verse 3\n", "I'm sippin' on sparkling water,\n", - "It's so crisp and clean,\n", - "It's the perfect way to keep me going,\n", - "On a hot summer day.\n", + "It's so light and so clear,\n", + "It's the perfect way to keep me cool\n", + "On a hot summer night.\n", "\n", "Chorus\n", "Sparkling water, sparkling water,\n", "It's the best way to stay hydrated,\n", - "It's so refreshing and light,\n", - "It's the perfect way to stay alive." + "It's so crisp and so clean,\n", + "It's the perfect way to stay refreshed." ] } ], "source": [ - "from langchain.llms import OpenAI\n", + "from langchain.llms import OpenAI, OpenAIChat\n", "from langchain.callbacks.base import CallbackManager\n", "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n", "\n", @@ -84,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "id": "a35373f1-9ee6-4753-a343-5aee749b8527", "metadata": { "tags": [] @@ -103,10 +103,10 @@ { "data": { "text/plain": [ - "LLMResult(generations=[[Generation(text='\\n\\nQ: What did the fish say when it hit the wall?\\nA: Dam!', generation_info={'finish_reason': 'stop', 'logprobs': None})]], llm_output={'token_usage': {}})" + "LLMResult(generations=[[Generation(text='\\n\\nQ: What did the fish say when it hit the wall?\\nA: Dam!', generation_info={'finish_reason': None, 'logprobs': None})]], llm_output={'token_usage': {}})" ] }, - "execution_count": 8, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -114,6 +114,85 @@ "source": [ "llm.generate([\"Tell me a joke.\"])" ] + }, + { + "cell_type": "markdown", + "id": "a93a4d61-0476-49db-8321-7de92bd74059", + "metadata": {}, + "source": [ + "Here's an example with `OpenAIChat`:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "22665f16-e05b-473c-a4bd-ad75744ea024", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Verse 1:\n", + "Bubbles rising to the top\n", + "A refreshing drink that never stops\n", + "Clear and crisp, it's pure delight\n", + "A taste that's sure to excite\n", + "\n", + "Chorus:\n", + "Sparkling water, oh so fine\n", + "A drink that's always on my mind\n", + "With every sip, I feel alive\n", + "Sparkling water, you're my vibe\n", + "\n", + "Verse 2:\n", + "No sugar, no calories, just pure bliss\n", + "A drink that's hard to resist\n", + "It's the perfect 
way to quench my thirst\n", + "A drink that always comes first\n", + "\n", + "Chorus:\n", + "Sparkling water, oh so fine\n", + "A drink that's always on my mind\n", + "With every sip, I feel alive\n", + "Sparkling water, you're my vibe\n", + "\n", + "Bridge:\n", + "From the mountains to the sea\n", + "Sparkling water, you're the key\n", + "To a healthy life, a happy soul\n", + "A drink that makes me feel whole\n", + "\n", + "Chorus:\n", + "Sparkling water, oh so fine\n", + "A drink that's always on my mind\n", + "With every sip, I feel alive\n", + "Sparkling water, you're my vibe\n", + "\n", + "Outro:\n", + "Sparkling water, you're the one\n", + "A drink that's always so much fun\n", + "I'll never let you go, my friend\n", + "Sparkling water, until the end." + ] + } + ], + "source": [ + "llm = OpenAIChat(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n", + "resp = llm(\"Write me a song about sparkling water.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eadae4ba-9f21-4ec8-845d-dd43b0edc2dc", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 43db1556b5667..35c7be9ddfd0b 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -319,6 +319,10 @@ class LLM(BaseLLM): def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: """Run the LLM on the given prompt and input.""" + async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str: + """Run the LLM on the given prompt and input.""" + raise NotImplementedError("Async generation not implemented for this LLM.") + def _generate( self, prompts: List[str], stop: Optional[List[str]] = None ) -> LLMResult: @@ -334,4 +338,8 @@ async def _agenerate( self, prompts: List[str], stop: Optional[List[str]] = None ) -> LLMResult: """Run the LLM on the given prompt and input.""" - raise NotImplementedError("Async generation not implemented for this LLM.") + generations = [] + for prompt in prompts: + text = await self._acall(prompt, stop=stop) + generations.append([Generation(text=text)]) + return LLMResult(generations=generations) diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index 90ef6bdda6725..f0646c4655919 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -1,4 +1,6 @@ """Wrapper around OpenAI APIs.""" +from __future__ import annotations + import logging import sys from typing import ( @@ -63,6 +65,53 @@ def _streaming_response_template() -> Dict[str, Any]: } +def _create_retry_decorator(llm: Union[BaseOpenAI, OpenAIChat]) -> Callable[[Any], Any]: + import openai + + min_seconds = 4 + max_seconds = 10 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(llm.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(openai.error.Timeout) + | retry_if_exception_type(openai.error.APIError) + | retry_if_exception_type(openai.error.APIConnectionError) + | retry_if_exception_type(openai.error.RateLimitError) + | retry_if_exception_type(openai.error.ServiceUnavailableError) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def completion_with_retry(llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = 
_create_retry_decorator(llm) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + return llm.client.create(**kwargs) + + return _completion_with_retry(**kwargs) + + +async def acompletion_with_retry( + llm: Union[BaseOpenAI, OpenAIChat], **kwargs: Any +) -> Any: + """Use tenacity to retry the async completion call.""" + retry_decorator = _create_retry_decorator(llm) + + @retry_decorator + async def _completion_with_retry(**kwargs: Any) -> Any: + # Use OpenAI's async api https://github.com/openai/openai-python#async-api + return await llm.client.acreate(**kwargs) + + return await _completion_with_retry(**kwargs) + + class BaseOpenAI(BaseLLM, BaseModel): """Wrapper around OpenAI large language models. @@ -174,48 +223,6 @@ def _default_params(self) -> Dict[str, Any]: } return {**normal_params, **self.model_kwargs} - def _create_retry_decorator(self) -> Callable[[Any], Any]: - import openai - - min_seconds = 4 - max_seconds = 10 - # Wait 2^x * 1 second between each retry starting with - # 4 seconds, then up to 10 seconds, then 10 seconds afterwards - return retry( - reraise=True, - stop=stop_after_attempt(self.max_retries), - wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), - retry=( - retry_if_exception_type(openai.error.Timeout) - | retry_if_exception_type(openai.error.APIError) - | retry_if_exception_type(openai.error.APIConnectionError) - | retry_if_exception_type(openai.error.RateLimitError) - | retry_if_exception_type(openai.error.ServiceUnavailableError) - ), - before_sleep=before_sleep_log(logger, logging.WARNING), - ) - - def completion_with_retry(self, **kwargs: Any) -> Any: - """Use tenacity to retry the completion call.""" - retry_decorator = self._create_retry_decorator() - - @retry_decorator - def _completion_with_retry(**kwargs: Any) -> Any: - return self.client.create(**kwargs) - - return _completion_with_retry(**kwargs) - - async def acompletion_with_retry(self, **kwargs: Any) -> Any: - """Use tenacity to retry the async completion call.""" - retry_decorator = self._create_retry_decorator() - - @retry_decorator - async def _completion_with_retry(**kwargs: Any) -> Any: - # Use OpenAI's async api https://github.com/openai/openai-python#async-api - return await self.client.acreate(**kwargs) - - return await _completion_with_retry(**kwargs) - def _generate( self, prompts: List[str], stop: Optional[List[str]] = None ) -> LLMResult: @@ -247,8 +254,8 @@ def _generate( raise ValueError("Cannot stream results with multiple prompts.") params["stream"] = True response = _streaming_response_template() - for stream_resp in self.completion_with_retry( - prompt=_prompts, **params + for stream_resp in completion_with_retry( + self, prompt=_prompts, **params ): self.callback_manager.on_llm_new_token( stream_resp["choices"][0]["text"], @@ -258,7 +265,7 @@ def _generate( _update_response(response, stream_resp) choices.extend(response["choices"]) else: - response = self.completion_with_retry(prompt=_prompts, **params) + response = completion_with_retry(self, prompt=_prompts, **params) choices.extend(response["choices"]) if not self.streaming: # Can't update token usage if streaming @@ -282,8 +289,8 @@ async def _agenerate( raise ValueError("Cannot stream results with multiple prompts.") params["stream"] = True response = _streaming_response_template() - async for stream_resp in await self.acompletion_with_retry( - prompt=_prompts, **params + async for stream_resp in await acompletion_with_retry( + self, prompt=_prompts, **params ): if 
self.callback_manager.is_async: await self.callback_manager.on_llm_new_token( @@ -300,7 +307,7 @@ async def _agenerate( _update_response(response, stream_resp) choices.extend(response["choices"]) else: - response = await self.acompletion_with_retry(prompt=_prompts, **params) + response = await acompletion_with_retry(self, prompt=_prompts, **params) choices.extend(response["choices"]) if not self.streaming: # Can't update token usage if streaming @@ -540,6 +547,9 @@ class OpenAIChat(BaseLLM, BaseModel): max_retries: int = 6 """Maximum number of retries to make when generating.""" prefix_messages: List = Field(default_factory=list) + """Series of messages for Chat input.""" + streaming: bool = False + """Whether to stream the results or not.""" class Config: """Configuration for this pydantic object.""" @@ -590,61 +600,82 @@ def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling OpenAI API.""" return self.model_kwargs - def _create_retry_decorator(self) -> Callable[[Any], Any]: - import openai - - min_seconds = 4 - max_seconds = 10 - # Wait 2^x * 1 second between each retry starting with - # 4 seconds, then up to 10 seconds, then 10 seconds afterwards - return retry( - reraise=True, - stop=stop_after_attempt(self.max_retries), - wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), - retry=( - retry_if_exception_type(openai.error.Timeout) - | retry_if_exception_type(openai.error.APIError) - | retry_if_exception_type(openai.error.APIConnectionError) - | retry_if_exception_type(openai.error.RateLimitError) - | retry_if_exception_type(openai.error.ServiceUnavailableError) - ), - before_sleep=before_sleep_log(logger, logging.WARNING), - ) - - def completion_with_retry(self, **kwargs: Any) -> Any: - """Use tenacity to retry the completion call.""" - retry_decorator = self._create_retry_decorator() - - @retry_decorator - def _completion_with_retry(**kwargs: Any) -> Any: - return self.client.create(**kwargs) - - return _completion_with_retry(**kwargs) - - def _generate( + def _get_chat_params( self, prompts: List[str], stop: Optional[List[str]] = None - ) -> LLMResult: + ) -> Tuple: if len(prompts) > 1: - raise ValueError(f"OpenAIChat only supports single prompts, got {prompts}") + raise ValueError( + f"OpenAIChat currently only supports single prompt, got {prompts}" + ) messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}] params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params} if stop is not None: if "stop" in params: raise ValueError("`stop` found in both the input and default params.") params["stop"] = stop - response = self.completion_with_retry(messages=messages, **params) - return LLMResult( - generations=[ - [Generation(text=response["choices"][0]["message"]["content"])] - ], - llm_output={"token_usage": response["usage"]}, - ) + return messages, params + + def _generate( + self, prompts: List[str], stop: Optional[List[str]] = None + ) -> LLMResult: + messages, params = self._get_chat_params(prompts, stop) + if self.streaming: + response = "" + params["stream"] = True + for stream_resp in completion_with_retry(self, messages=messages, **params): + token = stream_resp["choices"][0]["delta"].get("content", "") + response += token + self.callback_manager.on_llm_new_token( + token, + verbose=self.verbose, + ) + return LLMResult( + generations=[[Generation(text=response)]], + ) + else: + full_response = completion_with_retry(self, messages=messages, **params) + return LLMResult( + generations=[ + 
[Generation(text=full_response["choices"][0]["message"]["content"])] + ], + llm_output={"token_usage": full_response["usage"]}, + ) async def _agenerate( self, prompts: List[str], stop: Optional[List[str]] = None ) -> LLMResult: - """Run the LLM on the given prompt and input.""" - raise NotImplementedError("Async generation not implemented for this LLM.") + messages, params = self._get_chat_params(prompts, stop) + if self.streaming: + response = "" + params["stream"] = True + async for stream_resp in await acompletion_with_retry( + self, messages=messages, **params + ): + token = stream_resp["choices"][0]["delta"].get("content", "") + response += token + if self.callback_manager.is_async: + await self.callback_manager.on_llm_new_token( + token, + verbose=self.verbose, + ) + else: + self.callback_manager.on_llm_new_token( + token, + verbose=self.verbose, + ) + return LLMResult( + generations=[[Generation(text=response)]], + ) + else: + full_response = await acompletion_with_retry( + self, messages=messages, **params + ) + return LLMResult( + generations=[ + [Generation(text=full_response["choices"][0]["message"]["content"])] + ], + llm_output={"token_usage": full_response["usage"]}, + ) @property def _identifying_params(self) -> Mapping[str, Any]: diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py index d0bc9f9d1b4de..d5f7e61bbd8ff 100644 --- a/tests/integration_tests/llms/test_openai.py +++ b/tests/integration_tests/llms/test_openai.py @@ -7,7 +7,7 @@ from langchain.callbacks.base import CallbackManager from langchain.llms.loading import load_llm -from langchain.llms.openai import OpenAI +from langchain.llms.openai import OpenAI, OpenAIChat from langchain.schema import LLMResult from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler @@ -142,3 +142,57 @@ async def test_openai_async_streaming_callback() -> None: result = await llm.agenerate(["Write me a sentence with 100 words."]) assert callback_handler.llm_streams == 10 assert isinstance(result, LLMResult) + + +def test_openai_chat() -> None: + """Test OpenAIChat.""" + llm = OpenAIChat(max_tokens=10) + output = llm("Say foo:") + assert isinstance(output, str) + + +def test_openai_chat_streaming() -> None: + """Test OpenAIChat with streaming option.""" + llm = OpenAIChat(max_tokens=10, streaming=True) + output = llm("Say foo:") + assert isinstance(output, str) + + +def test_openai_chat_streaming_callback() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + llm = OpenAIChat( + max_tokens=10, + streaming=True, + temperature=0, + callback_manager=callback_manager, + verbose=True, + ) + llm("Write me a sentence with 100 words.") + assert callback_handler.llm_streams != 0 + + +@pytest.mark.asyncio +async def test_openai_chat_async_generate() -> None: + """Test async chat.""" + llm = OpenAIChat(max_tokens=10) + output = await llm.agenerate(["Hello, how are you?"]) + assert isinstance(output, LLMResult) + + +@pytest.mark.asyncio +async def test_openai_chat_async_streaming_callback() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + llm = OpenAIChat( + max_tokens=10, + streaming=True, + temperature=0, + callback_manager=callback_manager, + verbose=True, + ) + result = await llm.agenerate(["Write me a sentence 
with 100 words."]) + assert callback_handler.llm_streams != 0 + assert isinstance(result, LLMResult)
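
Putting the async half of this patch to use: `OpenAIChat.agenerate` now routes through the module-level `acompletion_with_retry`, so several chat completions can be awaited concurrently just as with `OpenAI`. The snippet below is a minimal sketch condensed from the updated `async_llm.ipynb`; it assumes a valid OpenAI API key is available in the environment, and uses `asyncio.run` for a plain script where the notebook simply awaits the coroutine inside Jupyter.

    import asyncio

    from langchain.llms import OpenAIChat

    async def async_generate(llm: OpenAIChat) -> None:
        # agenerate is the async counterpart of generate.
        resp = await llm.agenerate(["Hello, how are you?"])
        print(resp.generations[0][0].text)

    async def generate_concurrently() -> None:
        llm = OpenAIChat(temperature=0.9)
        # Fire ten requests at once; the calls are network-bound, so they overlap.
        await asyncio.gather(*(async_generate(llm) for _ in range(10)))

    asyncio.run(generate_concurrently())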
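
For the streaming half, `OpenAIChat` now honors `streaming=True` and forwards each token delta to the callback manager's `on_llm_new_token`, mirroring the existing `OpenAI` behavior. A minimal sketch adapted from the updated `streaming_llm.ipynb`, again assuming an OpenAI API key in the environment:

    from langchain.callbacks.base import CallbackManager
    from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
    from langchain.llms import OpenAIChat

    # Tokens are printed to stdout as they arrive via on_llm_new_token.
    llm = OpenAIChat(
        streaming=True,
        callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
        verbose=True,
        temperature=0,
    )
    resp = llm("Write me a song about sparkling water.")

As with `OpenAI`, token usage cannot be tallied while streaming: the streaming `OpenAIChat` path returns only the accumulated generation text, without a `token_usage` entry in `llm_output`.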