45 changes: 37 additions & 8 deletions examples/README.md
@@ -1,13 +1,21 @@
# Example

- [How to run Visual Question Answering with MiniGPT-4](#How-to-run-Visual-Question-Answering-with-MiniGPT-4)
- [How to set the **embedding** function](#How-to-set-the-embedding-function)
- [How to set the **data manager** class](#How-to-set-the-data-manager-class)
- [How to set the **similarity evaluation** interface](#How-to-set-the-similarity-evaluation-interface)
- [Other cache init params](#Other-cache-init-params)
- [How to run with session](#How-to-run-with-session)
- [How to use GPTCache server](#How-to-use-GPTCache-server)
- [Benchmark](#Benchmark)
- [Example](#example)
- [How to run Visual Question Answering with MiniGPT-4](#how-to-run-visual-question-answering-with-minigpt-4)
- [How to set the `embedding` function](#how-to-set-the-embedding-function)
- [Default embedding function](#default-embedding-function)
- [Suitable for embedding methods consisting of a cached storage and vector store](#suitable-for-embedding-methods-consisting-of-a-cached-storage-and-vector-store)
- [Custom embedding](#custom-embedding)
- [How to set the `data manager` class](#how-to-set-the-data-manager-class)
- [How to set the `similarity evaluation` interface](#how-to-set-the-similarity-evaluation-interface)
- [Request cache parameter customization](#request-cache-parameter-customization)
- [How to run with session](#how-to-run-with-session)
- [Run in `with` method](#run-in-with-method)
- [Custom Session](#custom-session)
- [How to use GPTCache server](#how-to-use-gptcache-server)
- [Start server](#start-server)
- [Benchmark](#benchmark)
- [How to use post-process function](#how-to-use-post-process-function)

## How to run Visual Question Answering with MiniGPT-4

@@ -686,3 +694,24 @@ similarity evaluation func: pair_evaluation (search distance)
| 0.95 | 0.12s | 425 | 25 | 549 |
| 0.9 | 0.23s | 804 | 77 | 118 |
| 0.8 | 0.26s | 904 | 92 | 3 |

## How to use post-process function

You can use the `LlmVerifier` post-process function to filter the cached answer list after recall. It plays the same role as `first` or `random_one`, but it calls an LLM to verify that the recalled question is truly similar to the user's question. You can define your own system prompt to decide when the LLM should reject a candidate, and you can point the verifier at a small model so the verification step only adds a small extra cost. If every candidate is rejected, the lookup is treated as a cache miss and the request falls through to the real LLM.
Example usage:

```python
from gptcache.processor.post import LlmVerifier

# ... (set up the onnx embedding, data_manager, custom_prompt, etc.)

cache.init(
    embedding_func=onnx.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
    post_process_messages_func=LlmVerifier(
        client=None,
        system_prompt=custom_prompt,
        model="gpt-3.5-turbo",
    ),
)
```
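
You can also call the underlying `llm_semantic_verification` function directly on a recalled candidate list, outside of `cache.init`. The snippet below mirrors the docstring example in `gptcache/processor/post.py`; the strings are placeholders.

```python
from gptcache.processor.post import llm_semantic_verification

messages = ["answer1", "answer2"]   # recalled cached answers
scores = [0.9, 0.5]                 # similarity scores for each answer

# Returns the best-scoring answer if the verifier LLM judges it consistent
# with the question, otherwise None.
answer = llm_semantic_verification(
    messages,
    scores,
    original_question="original question",
    model="gpt-3.5-turbo",
)
```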

See [processor/llm_verifier_example.py](./processor/llm_verifier_example.py) for a runnable example.
47 changes: 47 additions & 0 deletions examples/processor/llm_verifier_example.py
@@ -0,0 +1,47 @@
import time
import os

from gptcache import cache
from gptcache.adapter import openai
from gptcache.embedding import Onnx
from gptcache.manager import manager_factory
from gptcache.processor.post import LlmVerifier
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation

print("This example demonstrates how to use LLM verification with OpenAI's GPT-3.5 Turbo model.")
cache.set_openai_key()

onnx = Onnx()
data_manager = manager_factory("sqlite,faiss", vector_params={"dimension": onnx.dimension})

custom_prompt = """You are a helpful assistant. Your task is to verify whether the answer is semantically consistent with the question.
If the answer is consistent, respond with "yes". If it is not consistent, respond with "no".
You must only respond in "yes" or "no". """

verifier = LlmVerifier(client=None,
system_prompt=custom_prompt,
model="gpt-3.5-turbo")

cache.init(
embedding_func=onnx.to_embeddings,
data_manager=data_manager,
similarity_evaluation=SearchDistanceEvaluation(),
post_process_messages_func=verifier
)

question = 'what is github'

for _ in range(3):
start = time.time()
response = openai.ChatCompletion.create(
model='gpt-3.5-turbo',
messages=[{
'role': 'user',
'content': question
}],
)
print(f"Response: {response['choices'][0]['message']['content']}")
print(f"Time: {round(time.time() - start, 2)}s\n")
110 changes: 63 additions & 47 deletions gptcache/adapter/adapter.py
@@ -3,7 +3,7 @@
import numpy as np

from gptcache import cache
from gptcache.processor.post import temperature_softmax
from gptcache.processor.post import temperature_softmax, LlmVerifier
from gptcache.utils.error import NotInitError
from gptcache.utils.log import gptcache_log
from gptcache.utils.time import time_cal
@@ -189,6 +189,12 @@ def post_process():
scores=[t[0] for t in cache_answers],
temperature=temperature,
)
elif isinstance(chat_cache.post_process_messages_func, LlmVerifier):
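# Pass the scores and the user's original question so the verifier LLM can judge
# semantic consistency; it returns None when every candidate is rejected.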
return_message = chat_cache.post_process_messages_func(
messages=[t[1] for t in cache_answers],
scores=[t[0] for t in cache_answers],
original_question=pre_embedding_data
)
else:
return_message = chat_cache.post_process_messages_func(
[t[1] for t in cache_answers]
@@ -200,29 +206,30 @@
func_name="post_process",
report_func=chat_cache.report.post,
)()
chat_cache.report.hint_cache()
cache_whole_data = answers_dict.get(str(return_message))
if session and cache_whole_data:
chat_cache.data_manager.add_session(
cache_whole_data[2], session.name, pre_embedding_data
)
if cache_whole_data and not chat_cache.config.disable_report:
# user_question / cache_question / cache_question_id / cache_answer / similarity / consume time/ time
report_cache_data = cache_whole_data[3]
report_search_data = cache_whole_data[2]
chat_cache.data_manager.report_cache(
pre_store_data if isinstance(pre_store_data, str) else "",
report_cache_data.question
if isinstance(report_cache_data.question, str)
else "",
report_search_data[1],
report_cache_data.answers[0].answer
if isinstance(report_cache_data.answers[0].answer, str)
else "",
cache_whole_data[0],
round(time.time() - start_time, 6),
)
return cache_data_convert(return_message)
if return_message is not None:
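# Only report a cache hit when post-processing kept a candidate; a None result
# (e.g. rejected by LlmVerifier) falls through to the next cache / real LLM call below.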
chat_cache.report.hint_cache()
cache_whole_data = answers_dict.get(str(return_message))
if session and cache_whole_data:
chat_cache.data_manager.add_session(
cache_whole_data[2], session.name, pre_embedding_data
)
if cache_whole_data and not chat_cache.config.disable_report:
# user_question / cache_question / cache_question_id / cache_answer / similarity / consume time/ time
report_cache_data = cache_whole_data[3]
report_search_data = cache_whole_data[2]
chat_cache.data_manager.report_cache(
pre_store_data if isinstance(pre_store_data, str) else "",
report_cache_data.question
if isinstance(report_cache_data.question, str)
else "",
report_search_data[1],
report_cache_data.answers[0].answer
if isinstance(report_cache_data.answers[0].answer, str)
else "",
cache_whole_data[0],
round(time.time() - start_time, 6),
)
return cache_data_convert(return_message)

next_cache = chat_cache.next_cache
if next_cache:
@@ -444,6 +451,13 @@ def post_process():
scores=[t[0] for t in cache_answers],
temperature=temperature,
)
elif isinstance(chat_cache.post_process_messages_func, LlmVerifier):
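# As above: pass scores and the original question so the verifier LLM can judge
# semantic consistency; a rejection returns None.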
return_message = chat_cache.post_process_messages_func(
messages=[t[1] for t in cache_answers],
scores=[t[0] for t in cache_answers],
original_question=pre_embedding_data,
temperature=temperature,
)
else:
return_message = chat_cache.post_process_messages_func(
[t[1] for t in cache_answers]
@@ -455,36 +469,38 @@
func_name="post_process",
report_func=chat_cache.report.post,
)()
chat_cache.report.hint_cache()
cache_whole_data = answers_dict.get(str(return_message))
if session and cache_whole_data:
chat_cache.data_manager.add_session(
cache_whole_data[2], session.name, pre_embedding_data
)
if cache_whole_data:
# user_question / cache_question / cache_question_id / cache_answer / similarity / consume time/ time
report_cache_data = cache_whole_data[3]
report_search_data = cache_whole_data[2]
chat_cache.data_manager.report_cache(
pre_store_data if isinstance(pre_store_data, str) else "",
report_cache_data.question
if isinstance(report_cache_data.question, str)
else "",
report_search_data[1],
report_cache_data.answers[0].answer
if isinstance(report_cache_data.answers[0].answer, str)
else "",
cache_whole_data[0],
round(time.time() - start_time, 6),
)
return cache_data_convert(return_message)
if return_message is not None:
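# As above, skip hit reporting when post-processing returned None so the request
# falls through to a real LLM call.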
chat_cache.report.hint_cache()
cache_whole_data = answers_dict.get(str(return_message))
if session and cache_whole_data:
chat_cache.data_manager.add_session(
cache_whole_data[2], session.name, pre_embedding_data
)
if cache_whole_data:
# user_question / cache_question / cache_question_id / cache_answer / similarity / consume time/ time
report_cache_data = cache_whole_data[3]
report_search_data = cache_whole_data[2]
chat_cache.data_manager.report_cache(
pre_store_data if isinstance(pre_store_data, str) else "",
report_cache_data.question
if isinstance(report_cache_data.question, str)
else "",
report_search_data[1],
report_cache_data.answers[0].answer
if isinstance(report_cache_data.answers[0].answer, str)
else "",
cache_whole_data[0],
round(time.time() - start_time, 6),
)
return cache_data_convert(return_message)

next_cache = chat_cache.next_cache
if next_cache:
kwargs["cache_obj"] = next_cache
kwargs["cache_context"] = context
kwargs["cache_skip"] = cache_skip
kwargs["cache_factor"] = cache_factor
kwargs["search_only"] = search_only_flag
llm_data = adapt(
llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
)
116 changes: 116 additions & 0 deletions gptcache/processor/post.py
@@ -87,3 +87,119 @@ def temperature_softmax(messages: List[Any], scores: List[float], temperature: f
else:
m_s = list(zip(messages, scores))
return sorted(m_s, key=lambda x: x[1], reverse=True)[0][0]



def llm_semantic_verification(
messages: List[Any],
scores: List[float] = None,
original_question: str = None,
*,
client=None,
system_prompt: str = None,
model: str = "gpt-3.5-turbo",
**kwargs
) -> Any:
"""
Use LLM to verify whether the answer is semantically consistent with the question.
If the answer passes verification, return it; otherwise, return None (to trigger a real LLM call).

:param messages: A list of candidate outputs.
:type messages: List[Any]
:param scores: A list of evaluation scores corresponding to messages.
:type scores: List[float], optional
:param original_question: The original question string.
:type original_question: str, optional
:param client: LLM client object, defaults to None.
:type client: Any, optional
:param system_prompt: System prompt, defaults to None.
:type system_prompt: str, optional
:param model: LLM model name, defaults to "gpt-3.5-turbo".
:type model: str, optional
:param temperature: Accepted for interface compatibility but ignored; the verification call always uses temperature 0.
:type temperature: float, optional
:param kwargs: Other keyword arguments.
:return: The answer if it passes semantic verification, otherwise None.
:rtype: Any

Example:
.. code-block:: python

from gptcache.processor.post import llm_semantic_verification

messages = ["answer1", "answer2"]
scores = [0.9, 0.5]
question = "original question"
answer = llm_semantic_verification(messages, scores, original_question=question)
"""
if not messages or not original_question:
return None
import openai

# Select the answer with the highest score
best_answer = messages[0] if not scores else messages[scores.index(max(scores))]
# Default to the module-level client; a caller can also pass in an
# openai client instance (e.g. openai.OpenAI()).
if client is None:
client = openai
if system_prompt is None:
system_prompt = ("You are a strict semantic verification assistant. "
"… Only answer 'yes' or 'no'. If unsure, answer 'no'.")

try:
# Route to the chat completions endpoint: an openai>=1.0 client (or module)
# exposes `chat.completions`, while the legacy openai module exposes `ChatCompletion`.
chat_api = client.chat.completions if hasattr(client, "chat") else client.ChatCompletion
resp = chat_api.create(
model=model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user",
"content": f"Question: {original_question}\n"
f"Answer: {best_answer}\n"
f"Does this answer fully match the question? yes/no"}
],
temperature=0,
max_tokens=10
)
verdict = resp.choices[0].message.content.strip().lower()
if verdict == "yes":
return best_answer
except Exception as e:
print("LLM verification failed:", e)

return None


class LlmVerifier:
"""
LlmVerifier is a callable class that wraps the llm_semantic_verification function.
It stores the LLM client, system prompt, and model name for repeated semantic verification tasks.

:param client: LLM client object.
:type client: Any
:param system_prompt: System prompt for the LLM.
:type system_prompt: str
:param model: LLM model name, defaults to "gpt-3.5-turbo".
:type model: str, optional
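
Example (values are placeholders):
.. code-block:: python

from gptcache.processor.post import LlmVerifier

verifier = LlmVerifier(model="gpt-3.5-turbo")
answer = verifier(["answer1", "answer2"], scores=[0.9, 0.5],
original_question="original question")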
"""
def __init__(self, client=None, system_prompt=None, model="gpt-3.5-turbo"):
self.client = client
self.system_prompt = system_prompt
self.model = model

def __call__(self, messages, scores=None, original_question=None, **kwargs):
"""
Call the verifier to perform semantic verification using the stored client, prompt, and model.

:param messages: A list of candidate outputs.
:param scores: A list of evaluation scores corresponding to messages.
:param original_question: The original question string.
:param kwargs: Other keyword arguments forwarded to ``llm_semantic_verification``.
:return: The answer if it passes semantic verification, otherwise None.
"""
return llm_semantic_verification(
messages, scores=scores, original_question=original_question,
client=self.client, system_prompt=self.system_prompt,
model=self.model, **kwargs
)
2 changes: 1 addition & 1 deletion gptcache/utils/__init__.py
@@ -105,7 +105,7 @@ def import_huggingface_hub():


def import_onnxruntime():
_check_library("onnxruntime", package="onnxruntime==1.14.1")
_check_library("onnxruntime", package="onnxruntime==1.21.1")


def import_faiss():