@@ -286,18 +286,22 @@ def lazy_indices_regex(
286
286
translation_literal = TRANSLATION_LITERALS [language ]
287
287
# First get indices to predict
288
288
indices = get_prefix (indices_config .prefix_for_extraction , translation_literal )[:len_choices ]
289
- indice_str_re = f"(?P<indices>{ '|' .join ([re .escape (i ) for i in indices ])} )"
289
+ indices_escaped = [re .escape (i ) for i in indices ]
290
+ # We allow both (A) and A
291
+ indices_wrapped = [rf"(?:{ i } |\({ i } \))" for i in indices_escaped ]
292
+ indice_str_re = f"(?P<indices>{ '|' .join (indices_wrapped )} )"
290
293
291
294
# The answer keys are either surrounded with <space>**answer**., or '<space>answer.' or the same without the dot
292
295
full_stop_re = rf"[{ re .escape (translation_literal .full_stop )} \.]"
293
296
comma_re = rf"[{ re .escape (translation_literal .comma )} \,]"
294
297
colon_re = rf"[{ re .escape (translation_literal .colon )} \:]"
295
298
space_re = re .escape (translation_literal .sentence_space )
296
299
297
- answer_prefix_re = rf"(^|{ space_re } )(?:\*\*)?"
300
+ answer_prefix_re = rf"(?: ^|{ space_re } )(?:\*\*)?"
298
301
answer_suffix_re = rf"(?:\*\*)?(?:{ full_stop_re } |{ comma_re } |{ colon_re } |{ space_re } |$)"
299
302
answer_re = f"{ answer_prefix_re } { indice_str_re } { answer_suffix_re } "
300
303
answer_re_start = rf"^(?:\*\*)?{ indice_str_re } { answer_suffix_re } "
304
+ answer_re_line_start = rf"\n(?:\*\*)?{ indice_str_re } { answer_suffix_re } "
301
305
302
306
answer_word = f"(?i:{ translation_literal .answer } )"
303
307
@@ -320,8 +324,10 @@ def lazy_indices_regex(
320
324
(f"{ answer_word } { colon_re } .{{0,50}}?{ answer_re } " , 100 ),
321
325
# Answer word patterns
322
326
(f"{ answer_word } .{{0,50}}?{ answer_re } " , 150 ),
323
- # Start of line patterns
327
+ # Start of the string
324
328
(answer_re_start , 200 ),
329
+ # Start of the line
330
+ (answer_re_line_start , 210 ),
325
331
]
326
332
)
327
333
@@ -490,6 +496,15 @@ def extract_latex(
490
496
return latex_exprs [0 ], latex_strs [0 ]
491
497
492
498
499
+ def extract_indices (
500
+ match : re .Match , target_type : IndicesExtractionConfig , timeout_seconds : int
501
+ ) -> tuple [str | None , str ]:
502
+ def normalize_index (index : str ) -> str :
503
+ return index .replace ("(" , "" ).replace (")" , "" ).strip ()
504
+
505
+ return normalize_index (match .group ("indices" )), normalize_index (match .group ("indices" ))
506
+
507
+
493
508
def extract_match (
494
509
match : re .Match , target_type : ExtractionTarget , timeout_seconds : int
495
510
) -> tuple [Basic | MatrixBase | str | None , str ]:
@@ -510,7 +525,7 @@ def extract_match(
510
525
elif isinstance (target_type , ExprExtractionConfig ):
511
526
return extract_expr (match , timeout_seconds = timeout_seconds )
512
527
elif isinstance (target_type , IndicesExtractionConfig ):
513
- return match . group ( "indices" ), match . group ( "indices" )
528
+ return extract_indices ( match , target_type , timeout_seconds = timeout_seconds )
514
529
515
530
516
531
def extract_target_from_pred (
0 commit comments