import os

import numpy as np
import pandas as pd
from keras_xlnet.backend import keras
from keras_bert.layers import Extract
from keras_xlnet import PretrainedList, get_pretrained_paths
from keras_xlnet import Tokenizer, load_trained_model_from_checkpoint, ATTENTION_TYPE_BI
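
# Fine-tune the pretrained XLNet (base, cased) on MNLI: 3-way classification
# of premise/hypothesis pairs as neutral, entailment or contradiction.
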
EPOCH = 5
BATCH_SIZE = 16
SEQ_LEN = 100
MODEL_NAME = 'MNLI.h5'

CLASSES = {
    'neutral': 0,
    'entailment': 1,
    'contradiction': 2,
}

current_path = os.path.dirname(os.path.abspath(__file__))
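# train.tsv, dev_matched.tsv and dev_mismatched.tsv come from the GLUE MNLI
# download and are expected next to this script.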
train_path = os.path.join(current_path, 'train.tsv')
dev_matched_path = os.path.join(current_path, 'dev_matched.tsv')
dev_mismatched_path = os.path.join(current_path, 'dev_mismatched.tsv')
paths = get_pretrained_paths(PretrainedList.en_cased_base)
tokenizer = Tokenizer(paths.vocab)
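# get_pretrained_paths downloads and caches the en_cased_base checkpoint on
# first use; the tokenizer wraps its bundled SentencePiece vocabulary.
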
# Read data
class DataSequence(keras.utils.Sequence):
    """Batches the pre-encoded inputs for fit_generator/predict_generator."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        # Ceiling division, so the final (possibly smaller) batch is kept.
        return (len(self.y) + BATCH_SIZE - 1) // BATCH_SIZE

    def __getitem__(self, index):
        s = slice(index * BATCH_SIZE, (index + 1) * BATCH_SIZE)
        return [item[s] for item in self.x], self.y[s]


def generate_sequence(path):
    tokens, classes = [], []
    # error_bad_lines is the old pandas flag for skipping malformed rows
    # (pandas >= 1.3 spells it on_bad_lines='skip').
    df = pd.read_csv(path, sep='\t', error_bad_lines=False)
    for _, row in df.iterrows():
        text_a, text_b, cls = row['sentence1'], row['sentence2'], row['gold_label']
        if not isinstance(text_a, str) or not isinstance(text_b, str) or cls not in CLASSES:
            continue
        # At most 48 tokens of sentence A, 49 of sentence B and two SEP tokens
        # give SEQ_LEN - 1; the trailing CLS token completes SEQ_LEN. XLNet
        # pads on the left and places CLS at the end.
        encoded_a, encoded_b = tokenizer.encode(text_a)[:48], tokenizer.encode(text_b)[:49]
        encoded = encoded_a + [tokenizer.SYM_SEP] + encoded_b + [tokenizer.SYM_SEP]
        encoded = [tokenizer.SYM_PAD] * (SEQ_LEN - 1 - len(encoded)) + encoded + [tokenizer.SYM_CLS]
        tokens.append(encoded)
        classes.append(CLASSES[cls])
    tokens, classes = np.array(tokens), np.array(classes)
    # Segment ids are all zero except the final CLS position; the third input
    # is the memory length, all zeros to match memory_len=0 below.
    segments = np.zeros_like(tokens)
    segments[:, -1] = 1
    lengths = np.zeros_like(tokens[:, :1])
    return DataSequence([tokens, segments, lengths], classes)

train_seq = generate_sequence(train_path)
dev_matched_seq = generate_sequence(dev_matched_path)
dev_mismatched_seq = generate_sequence(dev_mismatched_path)

# Load pretrained model
model = load_trained_model_from_checkpoint(
    config_path=paths.config,
    checkpoint_path=paths.model,
    batch_size=BATCH_SIZE,
    memory_len=0,
    target_len=SEQ_LEN,
    in_train_phase=False,
    attention_type=ATTENTION_TYPE_BI,
)
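# memory_len=0 disables the Transformer-XL recurrence memory, and
# ATTENTION_TYPE_BI selects the bidirectional attention used for fine-tuning
# (rather than autoregressive pretraining).
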
# Build classification model
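# XLNet places the CLS token at the last position, so Extract(index=-1) pulls
# its hidden state; a tanh projection (768, the base model's hidden size),
# dropout and a 3-way softmax sit on top.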
last = Extract(index=-1, name='Extract')(model.output)
dense = keras.layers.Dense(units=768, activation='tanh', name='Dense')(last)
dropout = keras.layers.Dropout(rate=0.1, name='Dropout')(dense)
output = keras.layers.Dense(units=3, activation='softmax', name='Softmax')(dropout)
model = keras.models.Model(inputs=model.inputs, outputs=output)
model.summary()

# Fit model
if os.path.exists(MODEL_NAME):
    model.load_weights(MODEL_NAME)

model.compile(
    optimizer=keras.optimizers.Adam(lr=3e-5),  # 'lr' is the older Keras spelling of learning_rate
    loss='sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy'],
)

# fit_generator/predict_generator come from the older Keras 2 API; newer
# tf.keras takes a Sequence directly in fit()/predict().
model.fit_generator(
    generator=train_seq,
    validation_data=dev_matched_seq,
    epochs=EPOCH,
    callbacks=[keras.callbacks.EarlyStopping(monitor='val_loss', patience=2)],
)
model.save_weights(MODEL_NAME)

# Evaluation
# Evaluate on the dev sets, since the test set labels are not public.
for dev_seq in [dev_matched_seq, dev_mismatched_seq]:
    results = model.predict_generator(dev_seq, verbose=True).argmax(axis=-1)
    print('Accuracy: %.2f' % (100.0 * np.sum(results == dev_seq.y) / len(results)))
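

# A minimal single-pair inference sketch (not part of the original script):
# reuse the encoding from generate_sequence and map the argmax back to a
# label. The row is tiled to BATCH_SIZE in case the graph was built with a
# fixed batch size (batch_size was passed at load time above). The helper is
# defined here but never called.
ID_TO_CLASS = {v: k for k, v in CLASSES.items()}


def classify_pair(premise, hypothesis):
    encoded_a = tokenizer.encode(premise)[:48]
    encoded_b = tokenizer.encode(hypothesis)[:49]
    encoded = encoded_a + [tokenizer.SYM_SEP] + encoded_b + [tokenizer.SYM_SEP]
    encoded = [tokenizer.SYM_PAD] * (SEQ_LEN - 1 - len(encoded)) + encoded + [tokenizer.SYM_CLS]
    tokens = np.array([encoded] * BATCH_SIZE)
    segments = np.zeros_like(tokens)
    segments[:, -1] = 1
    lengths = np.zeros_like(tokens[:, :1])
    probs = model.predict([tokens, segments, lengths])
    return ID_TO_CLASS[int(probs[0].argmax())]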