Skip to content

Commit

Permalink
Merge pull request huggingface#31 from stevezheng23/dev/zheng/coqa
Browse files Browse the repository at this point in the history
update coqa ensemble runner & data processing pipeline
  • Loading branch information
stevezheng23 authored Nov 14, 2019
2 parents 161edaf + 5fd04a6 commit b5188b8
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
6 changes: 3 additions & 3 deletions examples/run_coqa_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@

from transformers import AdamW, WarmupLinearSchedule

from utils_coqa_kd import (read_coqa_examples, convert_examples_to_features,
RawResult, write_predictions, write_predictions_v2,
RawResultExtended, write_predictions_extended, InputFeatures)
from utils_coqa_ensemble import (read_coqa_examples, convert_examples_to_features,
RawResult, write_predictions, write_predictions_v2,
RawResultExtended, write_predictions_extended, InputFeatures)

# The follwing import is the official CoQA evaluation script (2.0).
# You can remove it from the dependencies if you are using this script outside of the library
Expand Down
7 changes: 3 additions & 4 deletions examples/utils_coqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,17 +1291,16 @@ def write_predictions_v2(all_examples, all_features, all_results, n_best_size,
final_text = pred.answer_type

if final_text in ["option_a", "option_b"]:
norm_question_tokens = CoQAEvaluator.normalize_answer(example.question_text).split(" ")
question_text = example.question_text.split("<q>")[-1]
norm_question_tokens = CoQAEvaluator.normalize_answer(question_text).split(" ")
if "or" in norm_question_tokens:
index = norm_question_tokens.index("or")
if index-1 >= 0 and index+1 < len(norm_question_tokens):
if final_text == "option_a":
final_text = norm_question_tokens[index-1]
if final_text == "option_b":
final_text = norm_question_tokens[index+1]

seen_predictions[final_text] = True


nbest.append(
_NbestPrediction(
text=final_text,
Expand Down

0 comments on commit b5188b8

Please sign in to comment.