
Commit 4eab996

Rested, SimFG, and pranaychandekar authored
Support caching of async completion and cache completion (zilliztech#513)
* Use the old version for the chromadb (zilliztech#492)

* added support for weaviate vector database (zilliztech#493)
  * added support for in-local db for weaviate vector store
  * added unit test case for weaviate vector store
  * resolved unit test case error for weaviate vector store
  * increased code coverage; resolved pylint issues (pylint: disabled C0413)

* Update the version to `0.1.37` (zilliztech#494)

* ✨ support caching of async completion and cache completion

* ✨ add streaming support for chatcompletion

* ✅ improve test coverage and formatting

* correct merge duplication

* correct update cache callback

* add additional tests for improved coverage

* remove redundant param in docstring

Signed-off-by: SimFG <bang.fu@zilliz.com>
Signed-off-by: Reuben Thomas-Davis <reubenestd@gmail.com>
Signed-off-by: pranaychandekar <pranayc6@gmail.com>
Co-authored-by: SimFG <bang.fu@zilliz.com>
Co-authored-by: Pranay Chandekar <pranayc6@gmail.com>
1 parent 637a4b5 commit 4eab996
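
For context, a minimal usage sketch of the headline change: the new async entry point on the cached ChatCompletion wrapper. The model, prompt, and default cache.init() setup are illustrative rather than part of this commit, and an OPENAI_API_KEY is assumed to be configured.

import asyncio

from gptcache import cache
from gptcache.adapter import openai

# Default exact-match cache; real setups would configure an embedding
# function and data manager before calling init().
cache.init()
cache.set_openai_key()

async def main():
    # The first call reaches OpenAI through the new _allm_handler;
    # repeating the same question is then served from the cache.
    response = await openai.ChatCompletion.acreate(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "what is github?"}],
    )
    print(response["choices"][0]["message"]["content"])

asyncio.run(main())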

File tree

3 files changed: +403 -34 lines


gptcache/adapter/adapter.py

Lines changed: 0 additions & 1 deletion
@@ -513,7 +513,6 @@ def update_cache_func(handled_llm_data, question=None):
                 == 0
             ):
                 chat_cache.flush()
-
     llm_data = update_cache_callback(
         llm_data, update_cache_func, *args, **kwargs
     )

gptcache/adapter/openai.py

Lines changed: 97 additions & 12 deletions
@@ -3,21 +3,21 @@
 import os
 import time
 from io import BytesIO
-from typing import Iterator, Any, List
+from typing import Any, AsyncGenerator, Iterator, List

 from gptcache import cache
-from gptcache.adapter.adapter import adapt
+from gptcache.adapter.adapter import aadapt, adapt
 from gptcache.adapter.base import BaseCacheLLM
 from gptcache.manager.scalar_data.base import Answer, DataType
 from gptcache.utils import import_openai, import_pillow
 from gptcache.utils.error import wrap_error
 from gptcache.utils.response import (
-    get_stream_message_from_openai_answer,
-    get_message_from_openai_answer,
-    get_text_from_openai_answer,
+    get_audio_text_from_openai_answer,
     get_image_from_openai_b64,
     get_image_from_openai_url,
-    get_audio_text_from_openai_answer,
+    get_message_from_openai_answer,
+    get_stream_message_from_openai_answer,
+    get_text_from_openai_answer,
 )
 from gptcache.utils.token import token_counter

@@ -56,15 +56,40 @@ class ChatCompletion(openai.ChatCompletion, BaseCacheLLM):
     @classmethod
     def _llm_handler(cls, *llm_args, **llm_kwargs):
         try:
-            return super().create(*llm_args, **llm_kwargs) if cls.llm is None else cls.llm(*llm_args, **llm_kwargs)
+            return (
+                super().create(*llm_args, **llm_kwargs)
+                if cls.llm is None
+                else cls.llm(*llm_args, **llm_kwargs)
+            )
+        except openai.OpenAIError as e:
+            raise wrap_error(e) from e
+
+    @classmethod
+    async def _allm_handler(cls, *llm_args, **llm_kwargs):
+        try:
+            return (
+                (await super().acreate(*llm_args, **llm_kwargs))
+                if cls.llm is None
+                else await cls.llm(*llm_args, **llm_kwargs)
+            )
         except openai.OpenAIError as e:
             raise wrap_error(e) from e

     @staticmethod
     def _update_cache_callback(
         llm_data, update_cache_func, *args, **kwargs
     ):  # pylint: disable=unused-argument
-        if not isinstance(llm_data, Iterator):
+        if isinstance(llm_data, AsyncGenerator):
+
+            async def hook_openai_data(it):
+                total_answer = ""
+                async for item in it:
+                    total_answer += get_stream_message_from_openai_answer(item)
+                    yield item
+                update_cache_func(Answer(total_answer, DataType.STR))
+
+            return hook_openai_data(llm_data)
+        elif not isinstance(llm_data, Iterator):
             update_cache_func(
                 Answer(get_message_from_openai_answer(llm_data), DataType.STR)
             )
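
The new AsyncGenerator branch above tees the stream: each chunk is yielded to the caller unchanged while its text accumulates, and update_cache_func fires only once the stream is exhausted. A minimal, self-contained sketch of the same pattern, with plain strings standing in for OpenAI chunk objects:

import asyncio
from typing import AsyncGenerator

async def tee_stream(stream: AsyncGenerator, on_complete) -> AsyncGenerator:
    # Pass chunks through untouched while collecting them; report the
    # joined text to the callback after the source is fully consumed.
    pieces = []
    async for chunk in stream:
        pieces.append(chunk)
        yield chunk
    on_complete("".join(pieces))

async def demo():
    async def fake_stream():
        for piece in ["Hello", ", ", "world"]:
            yield piece

    async for chunk in tee_stream(fake_stream(), lambda full: print("cached:", full)):
        print("chunk:", chunk)

asyncio.run(demo())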
@@ -92,8 +117,6 @@ def cache_data_convert(cache_data):
                 saved_token = [input_token, output_token]
             else:
                 saved_token = [0, 0]
-            if kwargs.get("stream", False):
-                return _construct_stream_resp_from_cache(cache_data, saved_token)
             return _construct_resp_from_cache(cache_data, saved_token)

         kwargs = cls.fill_base_args(**kwargs)
@@ -105,6 +128,38 @@ def cache_data_convert(cache_data):
             **kwargs,
         )

+    @classmethod
+    async def acreate(cls, *args, **kwargs):
+        chat_cache = kwargs.get("cache_obj", cache)
+        enable_token_counter = chat_cache.config.enable_token_counter
+
+        def cache_data_convert(cache_data):
+            if enable_token_counter:
+                input_token = _num_tokens_from_messages(kwargs.get("messages"))
+                output_token = token_counter(cache_data)
+                saved_token = [input_token, output_token]
+            else:
+                saved_token = [0, 0]
+            if kwargs.get("stream", False):
+                return async_iter(
+                    _construct_stream_resp_from_cache(cache_data, saved_token)
+                )
+            return _construct_resp_from_cache(cache_data, saved_token)
+
+        kwargs = cls.fill_base_args(**kwargs)
+        return await aadapt(
+            cls._allm_handler,
+            cache_data_convert,
+            cls._update_cache_callback,
+            *args,
+            **kwargs,
+        )
+
+
+async def async_iter(input_list):
+    for item in input_list:
+        yield item
+

 class Completion(openai.Completion, BaseCacheLLM):
     """Openai Completion Wrapper
@@ -128,7 +183,22 @@ class Completion(openai.Completion, BaseCacheLLM):
     @classmethod
     def _llm_handler(cls, *llm_args, **llm_kwargs):
         try:
-            return super().create(*llm_args, **llm_kwargs) if not cls.llm else cls.llm(*llm_args, **llm_kwargs)
+            return (
+                super().create(*llm_args, **llm_kwargs)
+                if not cls.llm
+                else cls.llm(*llm_args, **llm_kwargs)
+            )
+        except openai.OpenAIError as e:
+            raise wrap_error(e) from e
+
+    @classmethod
+    async def _allm_handler(cls, *llm_args, **llm_kwargs):
+        try:
+            return (
+                (await super().acreate(*llm_args, **llm_kwargs))
+                if cls.llm is None
+                else await cls.llm(*llm_args, **llm_kwargs)
+            )
         except openai.OpenAIError as e:
             raise wrap_error(e) from e

@@ -154,6 +224,17 @@ def create(cls, *args, **kwargs):
             **kwargs,
         )

+    @classmethod
+    async def acreate(cls, *args, **kwargs):
+        kwargs = cls.fill_base_args(**kwargs)
+        return await aadapt(
+            cls._allm_handler,
+            cls._cache_data_convert,
+            cls._update_cache_callback,
+            *args,
+            **kwargs,
+        )
+

 class Audio(openai.Audio):
     """Openai Audio Wrapper
@@ -319,7 +400,11 @@ class Moderation(openai.Moderation, BaseCacheLLM):
     @classmethod
     def _llm_handler(cls, *llm_args, **llm_kwargs):
         try:
-            return super().create(*llm_args, **llm_kwargs) if not cls.llm else cls.llm(*llm_args, **llm_kwargs)
+            return (
+                super().create(*llm_args, **llm_kwargs)
+                if not cls.llm
+                else cls.llm(*llm_args, **llm_kwargs)
+            )
         except openai.OpenAIError as e:
             raise wrap_error(e) from e
