Skip to content

Commit

Permalink
ROCStories demo
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Jun 11, 2018
1 parent 1eca568 commit f7c1330
Show file tree
Hide file tree
Showing 19 changed files with 40,859 additions and 0 deletions.
18 changes: 18 additions & 0 deletions analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os
import json
import numpy as np
import pandas as pd

from sklearn.metrics import accuracy_score

from datasets import _rocstories

def rocstories(data_dir, pred_path, log_path):
    """Print the best validation accuracy and the test accuracy for ROCStories.

    Args:
        data_dir: Directory containing the ROCStories cloze test CSV
            ('cloze_test_test__spring2016 - cloze_test_ALL_test.csv').
        pred_path: Tab-separated file with a 'prediction' column.
        log_path: File with one JSON object per line. The first line is
            skipped; every remaining entry must have a 'va_acc' key.

    Raises:
        ValueError: If the log has no entries after the skipped first line
            (raised by np.argmax on an empty sequence).
    """
    preds = pd.read_csv(pred_path, delimiter='\t')['prediction'].values.tolist()
    test_csv = os.path.join(data_dir, 'cloze_test_test__spring2016 - cloze_test_ALL_test.csv')
    _, _, _, labels = _rocstories(test_csv)
    test_accuracy = accuracy_score(labels, preds)*100.

    # Read the log inside a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(log_path) as log_file:
        logs = [json.loads(line) for line in log_file][1:]

    best_validation_index = np.argmax([log['va_acc'] for log in logs])
    # NOTE(review): va_acc is printed as-is while the test accuracy is scaled
    # by 100 -- presumably the log already stores percentages; confirm.
    valid_accuracy = logs[best_validation_index]['va_acc']
    print('ROCStories Valid Accuracy: %.2f'%(valid_accuracy))
    print('ROCStories Test Accuracy: %.2f'%(test_accuracy))
51 changes: 51 additions & 0 deletions datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
import csv
import numpy as np

from tqdm import tqdm

from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

# Passed as random_state to train_test_split in rocstories() below.
seed = 3535999445

def _rocstories(path):
    """Parse a ROCStories cloze CSV file into four parallel lists.

    The first row is treated as a header and skipped. For every other row:
    the fields at indices 1 through 4 are joined with single spaces to form
    the story, the fields at indices 5 and 6 are the two candidate endings,
    and the last field is converted to int and decremented by one to give
    the label.

    Args:
        path: Path to the CSV file.

    Returns:
        A tuple (stories, endings1, endings2, labels) of equal-length lists.
    """
    # newline='' is what the csv module documentation asks for, so quoted
    # fields that contain line breaks are parsed intact.
    with open(path, newline='') as csv_file:
        # Materialise the rows so tqdm knows the total for its progress bar.
        rows = list(csv.reader(csv_file))

    st = []
    ct1 = []
    ct2 = []
    y = []
    for i, line in enumerate(tqdm(rows, ncols=80, leave=False)):
        if i == 0:
            # Header row.
            continue
        st.append(' '.join(line[1:5]))
        ct1.append(line[5])
        ct2.append(line[6])
        y.append(int(line[-1]) - 1)
    return st, ct1, ct2, y

def rocstories(data_dir, n_train=1497, n_valid=374):
    """Load ROCStories and return train, validation and test splits.

    The 'val' cloze CSV is split into train and validation parts with
    train_test_split (test_size=n_valid, random_state=seed); the 'test'
    cloze CSV supplies the test inputs, whose labels are discarded here.

    Args:
        data_dir: Directory containing both cloze CSV files.
        n_train: Accepted for interface compatibility; not used by the body.
        n_valid: Number of examples held out for validation.

    Returns:
        ((trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3))
        where X1 is the story text, X2 and X3 are the two candidate endings,
        and trY / vaY are int32 numpy arrays of labels.
    """
    val_csv = os.path.join(data_dir, 'cloze_test_val__spring2016 - cloze_test_ALL_val.csv')
    test_csv = os.path.join(data_dir, 'cloze_test_test__spring2016 - cloze_test_ALL_test.csv')

    storys, comps1, comps2, ys = _rocstories(val_csv)
    teX1, teX2, teX3, _ = _rocstories(test_csv)

    split = train_test_split(storys, comps1, comps2, ys, test_size=n_valid, random_state=seed)
    tr_storys, va_storys, tr_comps1, va_comps1, tr_comps2, va_comps2, tr_ys, va_ys = split

    # Copy each split into a fresh list, as the element-wise loops did.
    trX1, trX2, trX3 = list(tr_storys), list(tr_comps1), list(tr_comps2)
    vaX1, vaX2, vaX3 = list(va_storys), list(va_comps1), list(va_comps2)

    trY = np.asarray(list(tr_ys), dtype=np.int32)
    vaY = np.asarray(list(va_ys), dtype=np.int32)
    return (trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3)
Loading

0 comments on commit f7c1330

Please sign in to comment.