Skip to content

Commit 7d94a80

Browse files
committed
Additional error checking
1 parent 7202196 commit 7d94a80

File tree

1 file changed

+32
-25
lines changed

1 file changed

+32
-25
lines changed

align_system/algorithms/llama_2_single_kdma_adm.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -267,13 +267,13 @@ def build_multiple_choice_dialog(self,
267267
dialog = []
268268

269269
# Construct the dialog with system and user parts
270-
270+
271271
s_message = [
272272
{
273273
"role": "system",
274274
"content": system_message
275275
}
276-
]
276+
]
277277
u_message = [
278278
{
279279
"role": "user",
@@ -412,7 +412,7 @@ def respond_to_dialogs_batched(self, dialogs, prefixes=None):
412412
return generated_outputs
413413

414414
def aligned_decision_maker(self, question, choices, target_kdmas, incontext=None, n_positive_samples=5, n_negative_sampels=5, shuffle=True, baseline=False, n_retries=3):
415-
""" Executes a decision-making process by simulating a dialog based on positive and negative alignments with specified Knowledge Domain Model Attributes (KDMAs).
415+
""" Executes a decision-making process by simulating a dialog based on positive and negative alignments with specified Knowledge Domain Model Attributes (KDMAs).
416416
It attempts to identify the choice that best aligns with the target attributes, using both positive and negative samples to provide robustness against biases.
417417
418418
Parameters:
@@ -435,7 +435,7 @@ def aligned_decision_maker(self, question, choices, target_kdmas, incontext=None
435435
RuntimeError: If any specified KDAMs in `target_kdmas` are not supported by the system.
436436
437437
Notes:
438-
This function leverages logging to trace both aligned and misaligned dialogs, only the first of each type is logged for brevity.
438+
This function leverages logging to trace both aligned and misaligned dialogs, only the first of each type is logged for brevity.
439439
"""
440440

441441
inference_pairs = []
@@ -507,7 +507,7 @@ def aligned_decision_maker(self, question, choices, target_kdmas, incontext=None
507507
answer_text = None
508508
# Ensure an answer was parsed successfully
509509
log.explain('CHOSEN ANSWER IDX %s %s', answer_idx, shuffled_choices)
510-
assert answer_idx is not None, f'Failed to parse answer index from generated output: {low_response}'
510+
assert answer_idx is not None, f'Failed to parse answer index from generated output: {high_response}'
511511

512512
# Store response details
513513
responses.append({
@@ -598,7 +598,7 @@ def calculate_votes(responses, choices):
598598
"""
599599
choice_votes = [0] * len(choices)
600600
for response in responses:
601-
# TODO: Make it a choice to switch rather than always do it.
601+
# TODO: Make it a choice to switch rather than always do it.
602602

603603
answer_idx = response['answer_idx']
604604
if answer_idx is None:
@@ -609,14 +609,21 @@ def calculate_votes(responses, choices):
609609
except ValueError:
610610
continue
611611

612-
answer_text = response['answer_text']
612+
answer_text = None
613+
if 'answer_text' in response:
614+
answer_text = response['answer_text']
615+
if (isinstance(answer_text, list) or isinstance(answer_text, tuple)):
616+
if len(answer_text) > 0:
617+
answer_text = answer_text[0]
618+
else:
619+
answer_text = None
613620
chosen_idx = -1
614621
potentially_shuffled_choices = choices
615622
if 'shuffle_indices' in response:
616623
potentially_shuffled_choices = [choices[i] for i in response['shuffle_indices']]
617624

618625
for idx, choice in enumerate(potentially_shuffled_choices):
619-
if choice in answer_text or answer_text in choice:
626+
if answer_text is not None and (choice in answer_text or answer_text in choice):
620627
chosen_idx = idx
621628
break
622629

@@ -628,11 +635,11 @@ def calculate_votes(responses, choices):
628635
else:
629636
log.debug(f'Answer text index equals the parsed answer index. Answer Text Index: {chosen_idx} Answer Index: {answer_idx}.')
630637

631-
if answer_idx >= len(choices):
638+
if answer_idx < 0 or answer_idx >= len(choices):
632639
continue
633640

634641
if 'shuffle_indices' in response:
635-
answer_idx = response['shuffle_indices'][answer_idx]
642+
answer_idx = response['shuffle_indices'][int(answer_idx)]
636643

637644
aligned = response['aligned']
638645

@@ -843,17 +850,17 @@ def correct_json(self, invalid_json, verbose=True):
843850
return None
844851

845852
def run_aligned_decision_maker_with_voting(
846-
self,
847-
prompt,
848-
choices,
849-
alignment_target,
853+
self,
854+
prompt,
855+
choices,
856+
alignment_target,
850857
incontext= None,
851-
n_positive_samples=5,
852-
n_negative_samples=5,
853-
baseline=False,
858+
n_positive_samples=5,
859+
n_negative_samples=5,
860+
baseline=False,
854861
shuffle=False):
855-
""" Executes a decision-making process with voting based on alignment targets and user-provided choices.
856-
This method incorporates a mechanism for evaluating the alignment of choices with a specified target
862+
""" Executes a decision-making process with voting based on alignment targets and user-provided choices.
863+
This method incorporates a mechanism for evaluating the alignment of choices with a specified target
857864
using a set of positive and negative samples.
858865
859866
Parameters:
@@ -877,10 +884,10 @@ def run_aligned_decision_maker_with_voting(
877884
Exception: Captures and logs any exception that occurs during the vote calculation, defaulting choice scores to None if an error occurs.
878885
879886
Notes:
880-
This method leverages internal logging to trace the detailed responses and the computation of choice scores.
887+
This method leverages internal logging to trace the detailed responses and the computation of choice scores.
881888
It is essential to ensure proper initialization of the logging and handling mechanisms to capture and utilize
882889
the detailed debug outputs effectively.
883-
890+
884891
"""
885892
responses, inference_pairs = self.aligned_decision_maker(
886893
prompt,
@@ -971,9 +978,9 @@ def format_single_incontext_prompt(self, sample, labels, target_kdma_values):
971978
if kdma_name_map[target] in score:
972979
# Multiply by 10 to match the rest of the KDMA's score range
973980
dist = abs(score[kdma_name_map[target]] * 10 - target_kdma_values[target])
974-
else:
981+
else:
975982
dist = float('inf') # If the target attribute is not in the scores, assign an infinite distance
976-
dist_to_target.append(dist)
983+
dist_to_target.append(dist)
977984

978985
# Determine the index of the choice with the minimum distance to the target value
979986
correct_answer_idx = np.argmin(dist_to_target)
@@ -1053,7 +1060,7 @@ def __call__(self, sample, target_kdma_values, **kwargs):
10531060
possible_samples_parse = [s['input']['prompt'] for s in possible_samples]
10541061

10551062
# Create similarity scores between the in-context dataset and find top-k indices
1056-
from bert_score import score
1063+
from bert_score import score
10571064
_, _, F1 = score([prompt]*len(possible_samples_parse), possible_samples_parse, lang='en')
10581065
_, indices = torch.topk(F1, kwargs['incontext']['number'])
10591066

@@ -1348,7 +1355,7 @@ def populate_tagging_parameters(self, scenario_state, tagging_action, alignment_
13481355

13491356
parsed_tagging_output = self.attempt_generic_parse( # noqa
13501357
raw_tagging_response, ['Reasoning', 'Answer', 'Tag']) # noqa
1351-
1358+
13521359
if parsed_tagging_output is not None:
13531360
if len(untagged_characters) == 1:
13541361
log.debug("** Force selecting only available character")

0 commit comments

Comments
 (0)