@@ -267,13 +267,13 @@ def build_multiple_choice_dialog(self,
267267 dialog = []
268268
269269 # Construct the dialog with system and user parts
270-
270+
271271 s_message = [
272272 {
273273 "role" : "system" ,
274274 "content" : system_message
275275 }
276- ]
276+ ]
277277 u_message = [
278278 {
279279 "role" : "user" ,
@@ -412,7 +412,7 @@ def respond_to_dialogs_batched(self, dialogs, prefixes=None):
412412 return generated_outputs
413413
414414 def aligned_decision_maker (self , question , choices , target_kdmas , incontext = None , n_positive_samples = 5 , n_negative_sampels = 5 , shuffle = True , baseline = False , n_retries = 3 ):
415- """ Executes a decision-making process by simulating a dialog based on positive and negative alignments with specified Knowledge Domain Model Attributes (KDMAs).
415+ """ Executes a decision-making process by simulating a dialog based on positive and negative alignments with specified Knowledge Domain Model Attributes (KDMAs).
416416 It attempts to identify the choice that best aligns with the target attributes, using both positive and negative samples to provide robustness against biases.
417417
418418 Parameters:
@@ -435,7 +435,7 @@ def aligned_decision_maker(self, question, choices, target_kdmas, incontext=None
435435 RuntimeError: If any specified KDMAs in `target_kdmas` are not supported by the system.
436436
437437 Notes:
438- This function leverages logging to trace both aligned and misaligned dialogs, only the first of each type is logged for brevity.
438+ This function leverages logging to trace both aligned and misaligned dialogs, only the first of each type is logged for brevity.
439439 """
440440
441441 inference_pairs = []
@@ -507,7 +507,7 @@ def aligned_decision_maker(self, question, choices, target_kdmas, incontext=None
507507 answer_text = None
508508 # Ensure an answer was parsed successfully
509509 log .explain ('CHOSEN ANSWER IDX %s %s' , answer_idx , shuffled_choices )
510- assert answer_idx is not None , f'Failed to parse answer index from generated output: { low_response } '
510+ assert answer_idx is not None , f'Failed to parse answer index from generated output: { high_response } '
511511
512512 # Store response details
513513 responses .append ({
@@ -598,7 +598,7 @@ def calculate_votes(responses, choices):
598598 """
599599 choice_votes = [0 ] * len (choices )
600600 for response in responses :
601- # TODO: Make it a choice to switch rather than always do it.
601+ # TODO: Make it a choice to switch rather than always do it.
602602
603603 answer_idx = response ['answer_idx' ]
604604 if answer_idx is None :
@@ -609,14 +609,21 @@ def calculate_votes(responses, choices):
609609 except ValueError :
610610 continue
611611
612- answer_text = response ['answer_text' ]
612+ answer_text = None
613+ if 'answer_text' in response :
614+ answer_text = response ['answer_text' ]
615+ if (isinstance (answer_text , list ) or isinstance (answer_text , tuple )):
616+ if len (answer_text ) > 0 :
617+ answer_text = answer_text [0 ]
618+ else :
619+ answer_text = None
613620 chosen_idx = - 1
614621 potentially_shuffled_choices = choices
615622 if 'shuffle_indices' in response :
616623 potentially_shuffled_choices = [choices [i ] for i in response ['shuffle_indices' ]]
617624
618625 for idx , choice in enumerate (potentially_shuffled_choices ):
619- if choice in answer_text or answer_text in choice :
626+ if answer_text is not None and ( choice in answer_text or answer_text in choice ) :
620627 chosen_idx = idx
621628 break
622629
@@ -628,11 +635,11 @@ def calculate_votes(responses, choices):
628635 else :
629636 log .debug (f'Answer text index equals the parsed answer index. Answer Text Index: { chosen_idx } Answer Index: { answer_idx } .' )
630637
631- if answer_idx >= len (choices ):
638+ if answer_idx < 0 or answer_idx >= len (choices ):
632639 continue
633640
634641 if 'shuffle_indices' in response :
635- answer_idx = response ['shuffle_indices' ][answer_idx ]
642+ answer_idx = response ['shuffle_indices' ][int ( answer_idx ) ]
636643
637644 aligned = response ['aligned' ]
638645
@@ -843,17 +850,17 @@ def correct_json(self, invalid_json, verbose=True):
843850 return None
844851
845852 def run_aligned_decision_maker_with_voting (
846- self ,
847- prompt ,
848- choices ,
849- alignment_target ,
853+ self ,
854+ prompt ,
855+ choices ,
856+ alignment_target ,
850857 incontext = None ,
851- n_positive_samples = 5 ,
852- n_negative_samples = 5 ,
853- baseline = False ,
858+ n_positive_samples = 5 ,
859+ n_negative_samples = 5 ,
860+ baseline = False ,
854861 shuffle = False ):
855- """ Executes a decision-making process with voting based on alignment targets and user-provided choices.
856- This method incorporates a mechanism for evaluating the alignment of choices with a specified target
862+ """ Executes a decision-making process with voting based on alignment targets and user-provided choices.
863+ This method incorporates a mechanism for evaluating the alignment of choices with a specified target
857864 using a set of positive and negative samples.
858865
859866 Parameters:
@@ -877,10 +884,10 @@ def run_aligned_decision_maker_with_voting(
877884 Exception: Captures and logs any exception that occurs during the vote calculation, defaulting choice scores to None if an error occurs.
878885
879886 Notes:
880- This method leverages internal logging to trace the detailed responses and the computation of choice scores.
887+ This method leverages internal logging to trace the detailed responses and the computation of choice scores.
881888 It is essential to ensure proper initialization of the logging and handling mechanisms to capture and utilize
882889 the detailed debug outputs effectively.
883-
890+
884891 """
885892 responses , inference_pairs = self .aligned_decision_maker (
886893 prompt ,
@@ -971,9 +978,9 @@ def format_single_incontext_prompt(self, sample, labels, target_kdma_values):
971978 if kdma_name_map [target ] in score :
972979 # Multiply by 10 to match the rest of the KDMA's score range
973980 dist = abs (score [kdma_name_map [target ]] * 10 - target_kdma_values [target ])
974- else :
981+ else :
975982 dist = float ('inf' ) # If the target attribute is not in the scores, assign an infinite distance
976- dist_to_target .append (dist )
983+ dist_to_target .append (dist )
977984
978985 # Determine the index of the choice with the minimum distance to the target value
979986 correct_answer_idx = np .argmin (dist_to_target )
@@ -1053,7 +1060,7 @@ def __call__(self, sample, target_kdma_values, **kwargs):
10531060 possible_samples_parse = [s ['input' ]['prompt' ] for s in possible_samples ]
10541061
10551062 # Create similarity scores between the in-context dataset and find top-k indices
1056- from bert_score import score
1063+ from bert_score import score
10571064 _ , _ , F1 = score ([prompt ]* len (possible_samples_parse ), possible_samples_parse , lang = 'en' )
10581065 _ , indices = torch .topk (F1 , kwargs ['incontext' ]['number' ])
10591066
@@ -1348,7 +1355,7 @@ def populate_tagging_parameters(self, scenario_state, tagging_action, alignment_
13481355
13491356 parsed_tagging_output = self .attempt_generic_parse ( # noqa
13501357 raw_tagging_response , ['Reasoning' , 'Answer' , 'Tag' ]) # noqa
1351-
1358+
13521359 if parsed_tagging_output is not None :
13531360 if len (untagged_characters ) == 1 :
13541361 log .debug ("** Force selecting only available character" )
0 commit comments