-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactored the joint q2ar evaluation script
- Loading branch information
Rowan Zellers
committed
Feb 14, 2019
1 parent
73a2408
commit ae532f6
Showing
2 changed files
with
64 additions
and
86 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
""" | ||
You can use this script to evaluate prediction files (valpreds.npy). Essentially this is needed if you want to, say, | ||
combine answer and rationale predictions. | ||
""" | ||
|
||
import numpy as np | ||
import json | ||
import os | ||
from config import VCR_ANNOTS_DIR | ||
import argparse | ||
|
||
# The script requires exactly this many score columns in each prediction array.
NUM_CHOICES = 4


def build_parser():
    """Create the command-line parser for the evaluation script.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``-answer_preds``,
        ``-rationale_preds`` and ``-split``.
    """
    parser = argparse.ArgumentParser(description='Evaluate question -> answer and rationale')
    parser.add_argument(
        '-answer_preds',
        dest='answer_preds',
        default='saves/flagship_answer/valpreds.npy',
        help='Location of question->answer predictions',
        type=str,
    )
    parser.add_argument(
        '-rationale_preds',
        dest='rationale_preds',
        default='saves/flagship_rationale/valpreds.npy',
        help='Location of question+answer->rationale predictions',
        type=str,
    )
    parser.add_argument(
        '-split',
        dest='split',
        default='val',
        help="Split you're using. Probably you want val.",
        type=str,
    )
    return parser


def load_labels(annots_path):
    """Read the gold labels from a JSON-lines annotation file.

    Each non-blank line must be a JSON object carrying the keys
    ``answer_label`` and ``rationale_label``.

    Args:
        annots_path: Path to the ``.jsonl`` annotation file.

    Returns:
        A tuple ``(answer_labels, rationale_labels)`` of NumPy arrays, one
        entry per annotation line, in file order.
    """
    answer_labels = []
    rationale_labels = []
    with open(annots_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            item = json.loads(line)
            answer_labels.append(item['answer_label'])
            rationale_labels.append(item['rationale_label'])
    return np.array(answer_labels), np.array(rationale_labels)


def compute_accuracies(answer_preds, rationale_preds, answer_labels, rationale_labels):
    """Compute answer, rationale and joint accuracy.

    An example counts as a joint hit only when both its answer prediction
    and its rationale prediction match their labels.

    Args:
        answer_preds: Array of shape ``(n, NUM_CHOICES)`` with answer scores.
        rationale_preds: Array of shape ``(n, NUM_CHOICES)`` with rationale scores.
        answer_labels: ``n`` gold answer indices.
        rationale_labels: ``n`` gold rationale indices.

    Returns:
        A tuple ``(answer_acc, rationale_acc, joint_acc)`` of mean hit rates.

    Raises:
        ValueError: If a prediction array is not ``(n, NUM_CHOICES)``, if the
            number of predictions and labels disagree, or if there are no
            examples to evaluate.
    """
    answer_preds = np.asarray(answer_preds)
    rationale_preds = np.asarray(rationale_preds)
    answer_labels = np.asarray(answer_labels)
    rationale_labels = np.asarray(rationale_labels)

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    for name, preds in (('answer', answer_preds), ('rationale', rationale_preds)):
        if preds.ndim != 2 or preds.shape[1] != NUM_CHOICES:
            raise ValueError('{} predictions must have shape (n, {}), got {}'.format(
                name, NUM_CHOICES, preds.shape))
    if answer_preds.shape[0] != answer_labels.size:
        raise ValueError('{} answer predictions but {} answer labels'.format(
            answer_preds.shape[0], answer_labels.size))
    if rationale_preds.shape[0] != rationale_labels.size:
        raise ValueError('{} rationale predictions but {} rationale labels'.format(
            rationale_preds.shape[0], rationale_labels.size))
    if answer_labels.size != rationale_labels.size:
        raise ValueError('{} answer labels but {} rationale labels'.format(
            answer_labels.size, rationale_labels.size))
    if answer_labels.size == 0:
        raise ValueError('no examples to evaluate')

    answer_hits = answer_preds.argmax(1) == answer_labels
    rationale_hits = rationale_preds.argmax(1) == rationale_labels
    joint_hits = answer_hits & rationale_hits

    return np.mean(answer_hits), np.mean(rationale_hits), np.mean(joint_hits)


def main(argv=None):
    """Load predictions and labels, then print the three accuracies.

    Args:
        argv: Optional list of command-line arguments; when ``None`` the
            arguments are taken from ``sys.argv``.
    """
    args = build_parser().parse_args(argv)

    answer_preds = np.load(args.answer_preds)
    rationale_preds = np.load(args.rationale_preds)

    annots_path = os.path.join(VCR_ANNOTS_DIR, '{}.jsonl'.format(args.split))
    answer_labels, rationale_labels = load_labels(annots_path)

    answer_acc, rationale_acc, joint_acc = compute_accuracies(
        answer_preds, rationale_preds, answer_labels, rationale_labels)

    print("Answer acc: {:.3f}".format(answer_acc), flush=True)
    print("Rationale acc: {:.3f}".format(rationale_acc), flush=True)
    print("Joint acc: {:.3f}".format(joint_acc), flush=True)


if __name__ == '__main__':
    main()