@@ -52,7 +52,7 @@ def get_default_model(llm_provider):
 }
 
 
-def save_individual_results(results, dataset_name: str, llm_provider: str):
+def save_individual_results(results, dataset_name: str, llm_provider: str, retrieval_verbose: bool = False):
     """Save detailed individual results to a file for analysis."""
     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
     filename = f"individual_results_{dataset_name}_{llm_provider}_{timestamp}.txt"
@@ -145,7 +145,21 @@ def save_individual_results(results, dataset_name: str, llm_provider: str):
                 f.write(f" Recall: {recall:.3f}\n")
                 f.write(f" F1: {f1:.3f}\n")
                 f.write(f" Retrieved memories: {retrieved_memories_count}\n\n")
-
+
+            # Display retrieved memories content if verbose and available
+            if retrieval_verbose and 'retrieved_memories_content' in result:
+                retrieved_memories = result['retrieved_memories_content']
+                f.write("🧠 RETRIEVED MEMORIES CONTENT:\n")
+                for i, memory in enumerate(retrieved_memories, 1):
+                    f.write(f" Memory {i} (Score: {memory['score']:.4f}):\n")
+                    content = memory['content']
+                    # Show first 500 characters, add ellipsis if truncated
+                    if len(content) > 500:
+                        f.write(f" \"{content[:500]}...\"\n")
+                    else:
+                        f.write(f" \"{content}\"\n")
+                    f.write("\n")
+
             f.write("=" * 80 + "\n\n")
 
         # Summary
@@ -169,27 +183,29 @@ def save_individual_results(results, dataset_name: str, llm_provider: str):
 def print_benchmark_summary(results, dataset_name):
     """Print detailed benchmark summary with histogram visualization."""
 
-    # Collect incorrect question IDs
-    incorrect_question_ids = []
+    # Collect incorrect question IDs with recall flags
+    incorrect_question_data = []
     for result in results.question_results:
         if 'is_correct' in result and not result['is_correct']:
             question_id = result.get('question_id', 'N/A')
             if question_id != 'N/A':
-                incorrect_question_ids.append(question_id)
+                recall = result.get('recall', 0.0)
+                recall_flag = 1 if recall > 0.0 else 0
+                incorrect_question_data.append((question_id, recall_flag))
 
-    # Write incorrect question IDs to file if any exist
-    if incorrect_question_ids:
+    # Write incorrect question IDs with recall flags to CSV file if any exist
+    if incorrect_question_data:
         import datetime
         timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
         results_dir = os.path.join(os.path.dirname(__file__), 'results')
-        filename = f"incorrect_questions_{dataset_name}_{timestamp}.txt"
+        filename = f"incorrect_questions_{dataset_name}_{timestamp}.csv"
         filepath = os.path.join(results_dir, filename)
 
         try:
             os.makedirs(results_dir, exist_ok=True)
             with open(filepath, 'w') as f:
-                for question_id in incorrect_question_ids:
-                    f.write(f"{question_id}\n")
+                for question_id, recall_flag in incorrect_question_data:
+                    f.write(f"{question_id},{recall_flag}\n")
             logger.info(f"Incorrect question IDs saved to: {filepath}")
         except Exception as e:
             logger.error(f"Failed to write incorrect question IDs to file: {e}")
@@ -205,8 +221,8 @@ def print_benchmark_summary(results, dataset_name):
             print(f" Model: {result.get('model_choice_idx')} | Correct: {result.get('correct_choice_idx')}")
             if 'retrieval_time_ms' in result:
                 print(f" Retrieval: {result['retrieval_time_ms']:.2f}ms")
-            # Show retrieval metrics for LME dataset
-            if dataset_name == "lme" and 'precision' in result:
+            # Show retrieval metrics for LME and MSC datasets
+            if dataset_name in ["lme", "msc"] and 'precision' in result:
                 print(f" Retrieval Metrics - P: {result['precision']:.3f}, R: {result['recall']:.3f}, F1: {result['f1']:.3f}")
         else:
             print(f"Q{i+1}: {result.get('question_id', 'N/A')} - ⚠️ {result.get('status', 'UNKNOWN')}")
@@ -229,18 +245,19 @@ def print_benchmark_summary(results, dataset_name):
     else:
         print(f"⚠️ {results.total_count - results.success_count} questions failed evaluation")
 
-    # Retrieval evaluation metrics (LME only)
-    if results.retrieval_metrics_available and dataset_name == "lme":
+    # Retrieval evaluation metrics (LME and MSC)
+    if results.retrieval_metrics_available and dataset_name in ["lme", "msc"]:
         print(f"\n🎯 RETRIEVAL EVALUATION METRICS:")
         print(f" Average Precision: {results.avg_precision:.3f}")
         print(f" Average Recall: {results.avg_recall:.3f}")
         print(f" Average F1 Score: {results.avg_f1:.3f}")
 
     # Show incorrect question IDs if any
-    if incorrect_question_ids:
+    if incorrect_question_data:
+        incorrect_question_ids = [question_id for question_id, _ in incorrect_question_data]
         print(f"\n❌ Incorrect Question IDs ({len(incorrect_question_ids)} total):")
         print(", ".join(incorrect_question_ids))
-        print(f"💾 Incorrect question IDs also saved to benchmarks/results/")
+        print(f"💾 Incorrect question IDs with recall flags saved to benchmarks/results/")
 
     # Retrieval time statistics
     if results.query_times:
@@ -283,20 +300,22 @@ async def main():
     parser.add_argument("--question-types", nargs="+", help="Filter by question types (LME only)")
     parser.add_argument("--question-ids-file", type=str, help="File containing question IDs to test (one per line)")
     parser.add_argument("--top-k", type=int, help="Override default TOP_K value for memory retrieval")
-    parser.add_argument("--llm-provider", type=str, choices=["gemini", "openai", "anthropic"],
-                        default="gemini", help="LLM provider to use (default: gemini)")
-
-    # Parse args partially to get the provider first
-    known_args, _ = parser.parse_known_args()
-    default_model = get_default_model(known_args.llm_provider)
-
-    parser.add_argument("--model", type=str, default=default_model,
-                        help=f"Model name (default for {known_args.llm_provider}: {default_model})")
-    parser.add_argument("--no-data-loading", action="store_true",
+    parser.add_argument("--llm-provider", type=str, choices=["gemini", "openai", "anthropic"],
+                        default="openai", help="LLM provider to use (default: openai)")
+    parser.add_argument("--model", type=str, help="Model name (provider-specific default will be used if not specified)")
+    parser.add_argument("--no-data-loading", action="store_true",
                         help="Skip loading haystack data per question (assumes data already loaded)")
+    parser.add_argument("--concurrent", type=int, default=1,
+                        help="Number of concurrent evaluations (default: 1)")
+    parser.add_argument("--retrieval-verbose", action="store_true",
+                        help="Save and display retrieved memory content in detailed results")
 
     args = parser.parse_args()
 
+    # Set provider-specific default model if not specified
+    if not args.model:
+        args.model = get_default_model(args.llm_provider)
+
     # Validate question-types argument
     if args.question_types and args.dataset != "lme":
         logger.warning(f"--question-types is only supported for LME dataset, ignoring for {args.dataset}")
@@ -307,7 +326,16 @@ async def main():
     if args.question_ids_file:
         try:
             with open(args.question_ids_file, 'r') as f:
-                question_ids_from_file = [line.strip() for line in f if line.strip()]
+                question_ids_from_file = []
+                for line in f:
+                    line = line.strip()
+                    if line:
+                        # Handle CSV format (question_id,recall_flag) or plain text (question_id only)
+                        if ',' in line:
+                            question_id = line.split(',')[0].strip()
+                        else:
+                            question_id = line
+                        question_ids_from_file.append(question_id)
             logger.info(f"Loaded {len(question_ids_from_file)} question IDs from {args.question_ids_file}")
 
             # When using question-ids-file, override conflicting options
@@ -372,11 +400,13 @@ async def main():
         model_name=model_name,
         llm_provider=args.llm_provider,
         skip_data_loading=args.no_data_loading,
+        concurrent=args.concurrent,
+        retrieval_verbose=args.retrieval_verbose,
         logger=logger
     )
 
     # Save individual results to file
-    save_individual_results(results, args.dataset, args.llm_provider)
+    save_individual_results(results, args.dataset, args.llm_provider, args.retrieval_verbose)
 
     # Print detailed benchmark summary with visualization
     print_benchmark_summary(results, args.dataset)
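
For reference: the `incorrect_questions_*.csv` file written by `print_benchmark_summary` holds one `question_id,recall_flag` row per missed question (the flag is 1 when retrieval recall was above zero), and the reworked `--question-ids-file` loader accepts either that CSV or the older one-ID-per-line format, so a failure list can be fed straight back into a re-run. Below is a minimal standalone sketch of that round trip; the file name and question IDs are made up for illustration, and the parsing mirrors the logic added in this commit rather than importing the benchmark script itself.

```python
from pathlib import Path

# Illustrative failure data in the same shape print_benchmark_summary collects:
# (question_id, recall_flag), flag = 1 if retrieval recall > 0.0 else 0.
incorrect_question_data = [("q_0012", 1), ("q_0047", 0)]

path = Path("incorrect_questions_example.csv")  # hypothetical file name
with path.open("w") as f:
    for question_id, recall_flag in incorrect_question_data:
        f.write(f"{question_id},{recall_flag}\n")

# Read it back the way the updated --question-ids-file handling does:
# keep the first comma-separated field, or the whole line if there is no comma.
question_ids = []
with path.open("r") as f:
    for line in f:
        line = line.strip()
        if line:
            question_ids.append(line.split(",")[0].strip() if "," in line else line)

assert question_ids == ["q_0012", "q_0047"]
print(question_ids)  # -> ['q_0012', 'q_0047']
```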