 from modelcache.utils.error import NotInitError
 from modelcache.utils.time import time_cal
 from modelcache.processor.pre import multi_analysis
+from FlagEmbedding import FlagReranker

+USE_RERANKER = True  # if True, score candidates with the reranker; otherwise use the original evaluation logic

 def adapt_query(cache_data_convert, *args, **kwargs):
     chat_cache = kwargs.pop("cache_obj", cache)
@@ -74,53 +76,102 @@ def adapt_query(cache_data_convert, *args, **kwargs):
     if rank_pre < rank_threshold:
         return

-    for cache_data in cache_data_list:
-        primary_id = cache_data[1]
-        ret = chat_cache.data_manager.get_scalar_data(
-            cache_data, extra_param=context.get("get_scalar_data", None)
-        )
-        if ret is None:
-            continue
+    if USE_RERANKER:
+        reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=False)
+        for cache_data in cache_data_list:
+            primary_id = cache_data[1]
+            ret = chat_cache.data_manager.get_scalar_data(
+                cache_data, extra_param=context.get("get_scalar_data", None)
+            )
+            if ret is None:
+                continue

-        if "deps" in context and hasattr(ret.question, "deps"):
-            eval_query_data = {
-                "question": context["deps"][0]["data"],
-                "embedding": None
-            }
-            eval_cache_data = {
-                "question": ret.question.deps[0].data,
-                "answer": ret.answers[0].answer,
-                "search_result": cache_data,
-                "embedding": None,
-            }
-        else:
-            eval_query_data = {
-                "question": pre_embedding_data,
-                "embedding": embedding_data,
-            }
+            rank = reranker.compute_score([pre_embedding_data, ret[0]], normalize=True)

-            eval_cache_data = {
-                "question": ret[0],
-                "answer": ret[1],
-                "search_result": cache_data,
-                "embedding": None
-            }
-        rank = chat_cache.similarity_evaluation.evaluation(
-            eval_query_data,
-            eval_cache_data,
-            extra_param=context.get("evaluation_func", None),
-        )
+            if "deps" in context and hasattr(ret.question, "deps"):
+                eval_query_data = {
+                    "question": context["deps"][0]["data"],
+                    "embedding": None
+                }
+                eval_cache_data = {
+                    "question": ret.question.deps[0].data,
+                    "answer": ret.answers[0].answer,
+                    "search_result": cache_data,
+                    "embedding": None,
+                }
+            else:
+                eval_query_data = {
+                    "question": pre_embedding_data,
+                    "embedding": embedding_data,
+                }
+
+                eval_cache_data = {
+                    "question": ret[0],
+                    "answer": ret[1],
+                    "search_result": cache_data,
+                    "embedding": None
+                }
+
+            if len(pre_embedding_data) <= 256:
+                if rank_threshold <= rank:
+                    cache_answers.append((rank, ret[1]))
+                    cache_questions.append((rank, ret[0]))
+                    cache_ids.append((rank, primary_id))
+            else:
+                if rank_threshold_long <= rank:
+                    cache_answers.append((rank, ret[1]))
+                    cache_questions.append((rank, ret[0]))
+                    cache_ids.append((rank, primary_id))
+    else:
+        # reranker disabled: fall back to the original similarity-evaluation logic
+        for cache_data in cache_data_list:
+            primary_id = cache_data[1]
+            ret = chat_cache.data_manager.get_scalar_data(
+                cache_data, extra_param=context.get("get_scalar_data", None)
+            )
+            if ret is None:
+                continue
+
+            if "deps" in context and hasattr(ret.question, "deps"):
+                eval_query_data = {
+                    "question": context["deps"][0]["data"],
+                    "embedding": None
+                }
+                eval_cache_data = {
+                    "question": ret.question.deps[0].data,
+                    "answer": ret.answers[0].answer,
+                    "search_result": cache_data,
+                    "embedding": None,
+                }
+            else:
+                eval_query_data = {
+                    "question": pre_embedding_data,
+                    "embedding": embedding_data,
+                }
+
+                eval_cache_data = {
+                    "question": ret[0],
+                    "answer": ret[1],
+                    "search_result": cache_data,
+                    "embedding": None
+                }
+            rank = chat_cache.similarity_evaluation.evaluation(
+                eval_query_data,
+                eval_cache_data,
+                extra_param=context.get("evaluation_func", None),
+            )
+
+            if len(pre_embedding_data) <= 256:
+                if rank_threshold <= rank:
+                    cache_answers.append((rank, ret[1]))
+                    cache_questions.append((rank, ret[0]))
+                    cache_ids.append((rank, primary_id))
+            else:
+                if rank_threshold_long <= rank:
+                    cache_answers.append((rank, ret[1]))
+                    cache_questions.append((rank, ret[0]))
+                    cache_ids.append((rank, primary_id))

-        if len(pre_embedding_data) <= 256:
-            if rank_threshold <= rank:
-                cache_answers.append((rank, ret[1]))
-                cache_questions.append((rank, ret[0]))
-                cache_ids.append((rank, primary_id))
-        else:
-            if rank_threshold_long <= rank:
-                cache_answers.append((rank, ret[1]))
-                cache_questions.append((rank, ret[0]))
-                cache_ids.append((rank, primary_id))
     cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
     cache_questions = sorted(cache_questions, key=lambda x: x[0], reverse=True)
     cache_ids = sorted(cache_ids, key=lambda x: x[0], reverse=True)
@@ -141,4 +192,4 @@ def adapt_query(cache_data_convert, *args, **kwargs):
         logging.info('update_hit_count except, please check!')

     chat_cache.report.hint_cache()
-    return cache_data_convert(return_message, return_query)
+    return cache_data_convert(return_message, return_query)
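
A note on the new scoring path: FlagEmbedding's `FlagReranker` is a cross-encoder, so it scores the raw (query, cached question) text pair directly rather than comparing precomputed embeddings, and `normalize=True` passes the raw logit through a sigmoid so the score lands in [0, 1] and can be compared against the existing `rank_threshold` / `rank_threshold_long` cutoffs. Below is a minimal standalone sketch of that path; the example strings and threshold values are illustrative only, not taken from this repo.

```python
from FlagEmbedding import FlagReranker

# Load the cross-encoder once at module scope. The diff above constructs it
# inside adapt_query, which reloads the model weights on every lookup;
# hoisting the instance avoids that cost.
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=False)

query = "How do I reset my password?"         # stands in for pre_embedding_data
candidate = "Steps to reset a user password"  # stands in for ret[0], the cached question

# A single [query, passage] pair yields one float; normalize=True applies
# a sigmoid so the score falls in [0, 1].
rank = reranker.compute_score([query, candidate], normalize=True)

rank_threshold = 0.9        # illustrative cutoffs, not the repo's values
rank_threshold_long = 0.95

# Same length-dependent rule as the diff: queries up to 256 characters use
# rank_threshold, longer ones use rank_threshold_long.
threshold = rank_threshold if len(query) <= 256 else rank_threshold_long
print("cache hit" if rank >= threshold else "cache miss", f"(score={rank:.3f})")
```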