added changes for MOE in export and download
KunalTiwary committed Jul 25, 2024
1 parent 6e03c54 commit 6e16fd4
Showing 6 changed files with 68 additions and 71 deletions.
26 changes: 26 additions & 0 deletions backend/dataset/migrations/0051_auto_20240725_0900.py
@@ -0,0 +1,26 @@
# Generated by Django 3.2.14 on 2024-07-25 09:00

from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("dataset", "0050_alter_interaction_no_of_turns"),
    ]

    operations = [
        migrations.RemoveField(
            model_name="promptanswer",
            name="eval_output_likert_score",
        ),
        migrations.AddField(
            model_name="promptanswer",
            name="prompt_output_pair_id",
            field=models.CharField(
                help_text="prompt_output_pair_id",
                max_length=16,
                null=True,
                verbose_name="prompt_output_pair_id",
            ),
        ),
    ]
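For reference, a minimal way to apply and sanity-check this schema change locally — a sketch assuming the standard Django management setup in backend/; the snippet is illustrative and not part of the commit:

# Apply the new dataset migration first (from the backend/ directory):
#   python manage.py migrate dataset 0051
# Then, in `python manage.py shell`, confirm the PromptAnswer schema:
from dataset.models import PromptAnswer

field_names = {f.name for f in PromptAnswer._meta.get_fields()}
assert "prompt_output_pair_id" in field_names         # added by this migration
assert "eval_output_likert_score" not in field_names  # removed by this migration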
12 changes: 6 additions & 6 deletions backend/dataset/models.py
@@ -788,12 +788,6 @@ class PromptAnswer(DatasetBase):
    language = models.CharField(
        verbose_name="language", choices=LANG_CHOICES, max_length=15
    )
    eval_output_likert_score = models.IntegerField(
        verbose_name="evaluation_prompt_response_rating",
        null=True,
        blank=True,
        help_text=("Rating of the prompt response"),
    )
    eval_form_output_json = models.JSONField(
        verbose_name="evaluation_form_output",
        null=True,
@@ -806,6 +800,12 @@ class PromptAnswer(DatasetBase):
        blank=True,
        help_text=("Time taken to complete the prompt response"),
    )
    prompt_output_pair_id = models.CharField(
        verbose_name="prompt_output_pair_id",
        max_length=16,
        help_text=("prompt_output_pair_id"),
        null=True,
    )

    def __str__(self):
        return str(self.id)
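As context for the new field, a minimal usage sketch (illustrative only; "pair_0001" is a made-up identifier):

from dataset.models import PromptAnswer

# Fetch every evaluation answer recorded against one prompt/output pair.
answers = PromptAnswer.objects.filter(prompt_output_pair_id="pair_0001")
for ans in answers:
    print(ans.id, ans.eval_form_output_json, ans.eval_time_taken)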
2 changes: 1 addition & 1 deletion backend/projects/project_registry.yaml
@@ -37,7 +37,7 @@ ModelOutputEvaluation:
    id: interaction_id
    annotations:
      - eval_form_output_json
      - eval_output_likert_score
      - prompt_output_pair_id
      - eval_time_taken
      - model
      - prompt
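A quick, illustrative check that the registry entry now exports prompt_output_pair_id — the relative path and the single-occurrence assumption come from this diff alone, not from the full file:

import yaml

with open("backend/projects/project_registry.yaml") as f:
    registry_text = f.read()

yaml.safe_load(registry_text)  # the file should still parse as valid YAML
assert "prompt_output_pair_id" in registry_text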
2 changes: 1 addition & 1 deletion backend/projects/tasks.py
@@ -685,7 +685,7 @@ def export_project_new_record(
        )

    elif project_type == "ModelInteractionEvaluation":
        item_data_list = get_attributes_for_ModelInteractionEvaluation(task, True)
        item_data_list = get_attributes_for_ModelInteractionEvaluation(task)
        for item in range(len(item_data_list)):
            data_item = dataset_model()
            data_item.instance_id = export_dataset_instance
88 changes: 34 additions & 54 deletions backend/projects/utils.py
@@ -228,23 +228,8 @@ def ocr_word_count(annotation_result):
    return word_count


# function to obtain the correct annotation object
def get_correct_annotation_obj(task):
    annotation = Annotation.objects.filter(task=task)
    correct_ann_obj = annotation[0]
    if len(annotation) == 2:
        for ann in annotation:
            if ann.annotation_type == REVIEWER_ANNOTATION:
                correct_ann_obj = ann
    elif len(annotation) == 3:
        for ann in annotation:
            if ann.annotation_type == SUPER_CHECKER_ANNOTATION:
                correct_ann_obj = ann
    return correct_ann_obj


def get_attributes_for_IDC(project, task):
    correct_ann_obj = get_correct_annotation_obj(task)
    correct_ann_obj = task.correct_annotation
    result_dict = {
        "interactions_json": correct_ann_obj.result,
        "language": project.tgt_language,
@@ -264,55 +249,50 @@ def get_prompt_output_by_id(prompt_output_pair_id, task_data_dict):
    return None, None


def get_attributes_for_ModelInteractionEvaluation(task, correction_annotation_present):
def get_attributes_for_ModelInteractionEvaluation(task):
    res = []
    if correction_annotation_present:
        correct_ann_obj = get_correct_annotation_obj(task)
    if task.correct_annotation:
        correct_ann_obj = task.correct_annotation
        annotation_result_json = correct_ann_obj.result
        interaction = Interaction.objects.get(id=task.data["interaction_id"])
        try:
            prompt_output_pair = get_prompt_output_by_id(
                annotation_result_json["prompt_output_pair_id"], task.data
            )
        except Exception as e:
            prompt_output_pair = ["", ""]
    else:
        annotation_result_json = task["annotations"][0]["result"]
        annotation_result_json = Annotation.objects.filter(task=task)[0].result
        annotation_result_json = (
            json.loads(annotation_result_json)
            if isinstance(annotation_result_json, str)
            else annotation_result_json
        )
        interaction = Interaction.objects.get(id=task["data"]["interaction_id"])
        interaction = Interaction.objects.get(id=task.data["interaction_id"])
    for a in annotation_result_json:
        try:
            prompt_output_pair = [
                annotation_result_json["prompt"],
                annotation_result_json["output"],
            ]
            prompt_output_pair = get_prompt_output_by_id(
                a["prompt_output_pair_id"], task.data
            )
        except Exception as e:
            prompt_output_pair = ["", ""]

    model = interaction.model
    language = interaction.language

    temp_attributes_obj = {
        "interaction_id": interaction,
        "model": model,
        "language": language,
        "prompt": prompt_output_pair[0],
        "output": prompt_output_pair[1],
    }
    if "questions_response" in annotation_result_json:
        temp_attributes_obj["eval_form_output_json"] = annotation_result_json[
            "questions_response"
        ]
    if "rating" in annotation_result_json:
        temp_attributes_obj["eval_output_likert_score"] = annotation_result_json[
            "rating"
        ]
    if "time_taken" in annotation_result_json:
        temp_attributes_obj["eval_time_taken"] = annotation_result_json["time_taken"]
    res.append(temp_attributes_obj)
        try:
            prompt_output_pair = [
                annotation_result_json["prompt"],
                annotation_result_json["output"],
            ]
        except Exception as e:
            prompt_output_pair = ["", ""]
        model = interaction.model
        language = interaction.language

        temp_attributes_obj = {
            "interaction_id": interaction,
            "model": model,
            "language": language,
            "prompt": prompt_output_pair[0],
            "output": prompt_output_pair[1],
        }
        if "prompt_output_pair_id" in a:
            temp_attributes_obj["prompt_output_pair_id"] = a["prompt_output_pair_id"]
        if "questions_response" in a:
            temp_attributes_obj["eval_form_output_json"] = a["questions_response"]
        if "time_taken" in a:
            temp_attributes_obj["eval_time_taken"] = a["time_taken"]
        res.append(temp_attributes_obj)

    return res

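To make the reworked helper concrete, a sketch of the per-item annotation result it now iterates over and the rough shape of each exported row; the values and the exact keys below are illustrative, mirroring the fields referenced in the diff above:

# Hypothetical annotation result for one task: a list with one entry per
# evaluated prompt/output pair.
annotation_result_json = [
    {
        "prompt_output_pair_id": "pair_0001",
        "questions_response": [{"question": "Is the output fluent?", "answer": "Yes"}],
        "time_taken": 12.5,
    },
    {"prompt_output_pair_id": "pair_0002", "time_taken": 8.0},
]

# Each corresponding entry appended to `res` then looks roughly like:
# {
#     "interaction_id": <Interaction instance>,
#     "model": ..., "language": ...,
#     "prompt": ..., "output": ...,            # resolved via get_prompt_output_by_id
#     "prompt_output_pair_id": "pair_0001",
#     "eval_form_output_json": [...],
#     "eval_time_taken": 12.5,
# }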
9 changes: 0 additions & 9 deletions backend/projects/views.py
@@ -3968,15 +3968,6 @@ def download(self, request, pk=None, *args, **kwargs):
                )
                task["data"]["interactions_json"] = annotation_result
                del task["annotations"]
            elif dataset_type == "Interaction":
                for task in tasks_list:
                    item_data_list = get_attributes_for_ModelInteractionEvaluation(
                        task, False
                    )
                    for it in item_data_list:
                        for key, value in it.items():
                            task["data"][key] = value
                    del task["annotations"]
            return DataExport.generate_export_file(project, tasks_list, export_type)
        except Project.DoesNotExist:
            ret_dict = {"message": "Project does not exist!"}
