33import os
44import time
55from io import BytesIO
6- from typing import Iterator , Any , List
6+ from typing import Any , AsyncGenerator , Iterator , List
77
88from gptcache import cache
9- from gptcache .adapter .adapter import adapt
9+ from gptcache .adapter .adapter import aadapt , adapt
1010from gptcache .adapter .base import BaseCacheLLM
1111from gptcache .manager .scalar_data .base import Answer , DataType
1212from gptcache .utils import import_openai , import_pillow
1313from gptcache .utils .error import wrap_error
1414from gptcache .utils .response import (
15- get_stream_message_from_openai_answer ,
16- get_message_from_openai_answer ,
17- get_text_from_openai_answer ,
15+ get_audio_text_from_openai_answer ,
1816 get_image_from_openai_b64 ,
1917 get_image_from_openai_url ,
20- get_audio_text_from_openai_answer ,
18+ get_message_from_openai_answer ,
19+ get_stream_message_from_openai_answer ,
20+ get_text_from_openai_answer ,
2121)
2222from gptcache .utils .token import token_counter
2323
@@ -56,15 +56,40 @@ class ChatCompletion(openai.ChatCompletion, BaseCacheLLM):
5656 @classmethod
5757 def _llm_handler (cls , * llm_args , ** llm_kwargs ):
5858 try :
59- return super ().create (* llm_args , ** llm_kwargs ) if cls .llm is None else cls .llm (* llm_args , ** llm_kwargs )
59+ return (
60+ super ().create (* llm_args , ** llm_kwargs )
61+ if cls .llm is None
62+ else cls .llm (* llm_args , ** llm_kwargs )
63+ )
64+ except openai .OpenAIError as e :
65+ raise wrap_error (e ) from e
66+
67+ @classmethod
68+ async def _allm_handler (cls , * llm_args , ** llm_kwargs ):
69+ try :
70+ return (
71+ (await super ().acreate (* llm_args , ** llm_kwargs ))
72+ if cls .llm is None
73+ else await cls .llm (* llm_args , ** llm_kwargs )
74+ )
6075 except openai .OpenAIError as e :
6176 raise wrap_error (e ) from e
6277
6378 @staticmethod
6479 def _update_cache_callback (
6580 llm_data , update_cache_func , * args , ** kwargs
6681 ): # pylint: disable=unused-argument
67- if not isinstance (llm_data , Iterator ):
82+ if isinstance (llm_data , AsyncGenerator ):
83+
84+ async def hook_openai_data (it ):
85+ total_answer = ""
86+ async for item in it :
87+ total_answer += get_stream_message_from_openai_answer (item )
88+ yield item
89+ update_cache_func (Answer (total_answer , DataType .STR ))
90+
91+ return hook_openai_data (llm_data )
92+ elif not isinstance (llm_data , Iterator ):
6893 update_cache_func (
6994 Answer (get_message_from_openai_answer (llm_data ), DataType .STR )
7095 )
@@ -92,8 +117,6 @@ def cache_data_convert(cache_data):
92117 saved_token = [input_token , output_token ]
93118 else :
94119 saved_token = [0 , 0 ]
95- if kwargs .get ("stream" , False ):
96- return _construct_stream_resp_from_cache (cache_data , saved_token )
97120 return _construct_resp_from_cache (cache_data , saved_token )
98121
99122 kwargs = cls .fill_base_args (** kwargs )
@@ -105,6 +128,38 @@ def cache_data_convert(cache_data):
105128 ** kwargs ,
106129 )
107130
    @classmethod
    async def acreate(cls, *args, **kwargs):
        """Async cache-aware chat completion entry point.

        Consults the cache first (via ``aadapt``); on a miss, invokes the
        async LLM handler and lets the update callback store the answer.
        """
        # Fall back to the global ``cache`` object when the caller does not
        # pass an explicit ``cache_obj``.
        chat_cache = kwargs.get("cache_obj", cache)
        enable_token_counter = chat_cache.config.enable_token_counter

        def cache_data_convert(cache_data):
            # Shape a cache hit like a real OpenAI response.  NOTE: this
            # closure reads ``kwargs`` at call time, so it sees the
            # ``fill_base_args``-augmented mapping rebound below.
            if enable_token_counter:
                # Estimate tokens saved by serving from cache (input + output).
                input_token = _num_tokens_from_messages(kwargs.get("messages"))
                output_token = token_counter(cache_data)
                saved_token = [input_token, output_token]
            else:
                saved_token = [0, 0]
            if kwargs.get("stream", False):
                # Streaming callers expect an async iterable of chunks, so
                # wrap the synchronous stream-response list.
                return async_iter(
                    _construct_stream_resp_from_cache(cache_data, saved_token)
                )
            return _construct_resp_from_cache(cache_data, saved_token)

        # Merge class-level default arguments into this call's kwargs; the
        # rebinding is intentionally visible to ``cache_data_convert``.
        kwargs = cls.fill_base_args(**kwargs)
        return await aadapt(
            cls._allm_handler,
            cache_data_convert,
            cls._update_cache_callback,
            *args,
            **kwargs,
        )
157+
158+
async def async_iter(input_list):
    """Expose a plain (synchronous) iterable as an async generator,
    yielding each item in order."""
    iterator = iter(input_list)
    while True:
        try:
            value = next(iterator)
        except StopIteration:
            return
        yield value
162+
108163
109164class Completion (openai .Completion , BaseCacheLLM ):
110165 """Openai Completion Wrapper
@@ -128,7 +183,22 @@ class Completion(openai.Completion, BaseCacheLLM):
128183 @classmethod
129184 def _llm_handler (cls , * llm_args , ** llm_kwargs ):
130185 try :
131- return super ().create (* llm_args , ** llm_kwargs ) if not cls .llm else cls .llm (* llm_args , ** llm_kwargs )
186+ return (
187+ super ().create (* llm_args , ** llm_kwargs )
188+ if not cls .llm
189+ else cls .llm (* llm_args , ** llm_kwargs )
190+ )
191+ except openai .OpenAIError as e :
192+ raise wrap_error (e ) from e
193+
194+ @classmethod
195+ async def _allm_handler (cls , * llm_args , ** llm_kwargs ):
196+ try :
197+ return (
198+ (await super ().acreate (* llm_args , ** llm_kwargs ))
199+ if cls .llm is None
200+ else await cls .llm (* llm_args , ** llm_kwargs )
201+ )
132202 except openai .OpenAIError as e :
133203 raise wrap_error (e ) from e
134204
@@ -154,6 +224,17 @@ def create(cls, *args, **kwargs):
154224 ** kwargs ,
155225 )
156226
227+ @classmethod
228+ async def acreate (cls , * args , ** kwargs ):
229+ kwargs = cls .fill_base_args (** kwargs )
230+ return await aadapt (
231+ cls ._allm_handler ,
232+ cls ._cache_data_convert ,
233+ cls ._update_cache_callback ,
234+ * args ,
235+ ** kwargs ,
236+ )
237+
157238
158239class Audio (openai .Audio ):
159240 """Openai Audio Wrapper
@@ -319,7 +400,11 @@ class Moderation(openai.Moderation, BaseCacheLLM):
319400 @classmethod
320401 def _llm_handler (cls , * llm_args , ** llm_kwargs ):
321402 try :
322- return super ().create (* llm_args , ** llm_kwargs ) if not cls .llm else cls .llm (* llm_args , ** llm_kwargs )
403+ return (
404+ super ().create (* llm_args , ** llm_kwargs )
405+ if not cls .llm
406+ else cls .llm (* llm_args , ** llm_kwargs )
407+ )
323408 except openai .OpenAIError as e :
324409 raise wrap_error (e ) from e
325410
0 commit comments