Skip to content

Commit c10563d

Browse files
committed
Fix the wrong return value of onnx similarity evaluation
Signed-off-by: SimFG <bang.fu@zilliz.com>
1 parent de35ae6 commit c10563d

File tree

7 files changed

+12
-6
lines changed

7 files changed

+12
-6
lines changed

gptcache/adapter/adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
166166
rank,
167167
)
168168
if rank_threshold <= rank:
169-
cache_answers.append((rank, cache_data.answers[0].answer, search_data))
169+
cache_answers.append((float(rank), cache_data.answers[0].answer, search_data))
170170
chat_cache.data_manager.hit_cache_callback(search_data)
171171
cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
172172
answers_dict = dict((d[1], d[2]) for d in cache_answers)
@@ -397,7 +397,7 @@ async def aadapt(llm_handler, cache_data_convert, update_cache_callback, *args,
397397
rank,
398398
)
399399
if rank_threshold <= rank:
400-
cache_answers.append((rank, cache_data.answers[0].answer, search_data))
400+
cache_answers.append((float(rank), cache_data.answers[0].answer, search_data))
401401
chat_cache.data_manager.hit_cache_callback(search_data)
402402
cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
403403
answers_dict = dict((d[1], d[2]) for d in cache_answers)

gptcache/manager/scalar_data/manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from gptcache.manager.scalar_data.mongo import MongoStorage
21
from gptcache.utils import import_sql_client
32
from gptcache.utils.error import NotFoundError
43

@@ -85,6 +84,8 @@ def get(name, **kwargs):
8584
table_len_config=table_len_config,
8685
)
8786
elif name == "mongo":
87+
from gptcache.manager.scalar_data.mongo import MongoStorage
88+
8889
return MongoStorage(
8990
host=kwargs.get("mongo_host", "localhost"),
9091
port=kwargs.get("mongo_port", 27017),

gptcache/processor/post.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import random
22
from typing import List, Any
3+
34
import numpy
45

56
from gptcache.utils import softmax

gptcache/similarity_evaluation/distance.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Tuple, Dict, Any
2+
23
from gptcache.similarity_evaluation import SimilarityEvaluation
34

45

gptcache/similarity_evaluation/onnx.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from typing import Dict, List, Tuple, Any
2+
23
import numpy as np
4+
5+
from gptcache.similarity_evaluation import SimilarityEvaluation
36
from gptcache.utils import (
47
import_onnxruntime,
58
import_huggingface_hub,
69
import_huggingface,
710
)
8-
from gptcache.similarity_evaluation import SimilarityEvaluation
911

1012
import_onnxruntime()
1113
import_huggingface_hub()
@@ -130,4 +132,4 @@ def inference(self, reference: str, candidates: List[str]) -> np.ndarray:
130132
}
131133
ort_outputs = self.ort_session.run(None, ort_inputs)
132134
scores = ort_outputs[0][:, 1]
133-
return scores
135+
return float(scores[0])

gptcache/utils/softmax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
def softmax(x: list):
55
x = np.array(x)
6-
assert len(x.shape) == 1, f"Expect to get a shape of (len,) but got {x.shape}."
6+
assert len(x.shape) == 1, f"Expect to get a shape of (len,) but got {x.shape}, x value: {x}."
77
max_val = x.max()
88
e_x = np.exp(x - max_val)
99
return e_x / e_x.sum()

tests/unit_tests/similarity_evaluation/test_evaluation_onnx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def _test_evaluation(evaluation):
1717
candidate_2 = "how old are you?"
1818

1919
score = evaluation.evaluation({"question": query}, {"question": candidate_1})
20+
assert isinstance(score, float), type(score)
2021
assert score > 0.8
2122

2223
score = evaluation.evaluation({"question": query}, {"question": candidate_2})

0 commit comments

Comments
 (0)