Skip to content

Commit

Permalink
add files
Browse files Browse the repository at this point in the history
  • Loading branch information
seanie12 committed Apr 29, 2019
1 parent 4f0ccc9 commit 63a0075
Show file tree
Hide file tree
Showing 5 changed files with 655 additions and 77 deletions.
18 changes: 14 additions & 4 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,34 @@
embedding = "./data/embedding.pkl"
word2idx_file = "./data/word2idx.pkl"

model_path = "./save/seq2seq/train_422203728/20_2.68"
train = False
model_path = None
train = True
device = "cuda:1"
use_gpu = True
debug = False
vocab_size = 45000
freeze_embedding = True

num_epochs = 20
max_len = 400
max_length = 400
max_seq_len = 364
max_query_len = 64
num_layers = 2
hidden_size = 300
embedding_size = 300
lr = 0.1

# QA config
qa_lr = 5e-5
gradient_accumulation_steps = 1
warmup_proportion = 0.1
dual_lambda = 0.1

lr = 1e-3
batch_size = 64
dropout = 0.3
max_grad_norm = 5.0


use_tag = True
use_pointer = True
beam_size = 10
Expand Down
27 changes: 25 additions & 2 deletions data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,14 +249,14 @@ def merge(sequences):
def get_loader(src_file, trg_file, word2idx,
batch_size, use_tag=False, debug=False, shuffle=False):
if use_tag:
dataset = SQuadDatasetWithTag(src_file, trg_file, config.max_len,
dataset = SQuadDatasetWithTag(src_file, trg_file, config.max_seq_len,
word2idx, debug)
dataloader = data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
collate_fn=collate_fn_tag)
else:
dataset = SQuadDataset(src_file, trg_file, config.max_len,
dataset = SQuadDataset(src_file, trg_file, config.max_seq_len,
word2idx, debug)
dataloader = data.DataLoader(dataset=dataset,
batch_size=batch_size,
Expand Down Expand Up @@ -565,3 +565,26 @@ def make_conll_format(examples, src_file, trg_file):

src_fw.close()
trg_fw.close()


def split_dev(input_file, dev_file, test_file, dev_ratio=0.5):
    """Split a SQuAD-style JSON file into a new dev file and a test file.

    The list stored under the top-level ``"data"`` key is cut in its
    original order (no shuffling): the first ``int(len(data) * dev_ratio)``
    entries go to ``dev_file`` and the remaining entries go to
    ``test_file``. Each output is written as ``{"data": [...]}``; other
    top-level keys of the input are not copied.

    Args:
        input_file: path of the JSON file to read.
        dev_file: path the dev portion is written to.
        test_file: path the test portion is written to.
        dev_ratio: fraction of entries assigned to the dev portion;
            defaults to 0.5 (an even split).

    Raises:
        ValueError: if ``dev_ratio`` is outside ``[0, 1]``.
        KeyError: if the input JSON has no ``"data"`` key.
    """
    if not 0.0 <= dev_ratio <= 1.0:
        raise ValueError("dev_ratio must be within [0, 1], got %r" % (dev_ratio,))

    # Read explicitly as UTF-8 so non-ASCII passages do not depend on the
    # platform's default encoding.
    with open(input_file, encoding="utf-8") as f:
        squad = json.load(f)

    entries = squad["data"]

    # split the original SQuAD dev set into new dev / test set
    num_dev = int(len(entries) * dev_ratio)

    with open(dev_file, "w", encoding="utf-8") as f:
        json.dump({"data": entries[:num_dev]}, f)

    with open(test_file, "w", encoding="utf-8") as f:
        json.dump({"data": entries[num_dev:]}, f)
65 changes: 37 additions & 28 deletions infenrence.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from model import Seq2seq
import os
from data_utils import START_TOKEN, END_ID, get_loader, UNK_ID, outputids2words
from squad_utils import read_squad_examples, convert_examples_to_features
import torch
from torch.utils.data import SequentialSampler, DataLoader, TensorDataset
import torch.nn.functional as F
import config
import pickle
Expand Down Expand Up @@ -52,6 +54,25 @@ def __init__(self, model_path, output_dir):
if not os.path.exists(output_dir):
os.makedirs(output_dir)

def get_data_loader(self, file):
    """Build a sequential, batch-size-1 DataLoader over SQuAD features.

    Reads examples from ``file`` with ``read_squad_examples``, converts
    them with ``convert_examples_to_features`` using ``self.tokenizer``
    and the sequence/query limits from ``config``, and packs each
    feature's context ids, context length and question ids into a
    ``TensorDataset``.

    Args:
        file: path handed to ``read_squad_examples``.

    Returns:
        A ``DataLoader`` yielding ``(c_ids, c_lens, q_ids)`` batches of
        size 1 in dataset order.
    """
    examples = read_squad_examples(file, is_training=True, debug=config.debug)
    features = convert_examples_to_features(examples,
                                            tokenizer=self.tokenizer,
                                            max_seq_length=config.max_seq_len,
                                            max_query_length=config.max_query_len,
                                            doc_stride=128,
                                            is_training=True)

    all_c_ids = torch.tensor([f.c_ids for f in features], dtype=torch.long)
    # Per-row token count: take the sign of every id first, then sum.
    # The previous sign(sum(...)) collapsed each row to 0 or 1 and so never
    # produced an actual length.
    # NOTE(review): assumes padding id is 0 and real ids are positive —
    # confirm against the tokenizer.
    all_c_lens = torch.sum(torch.sign(all_c_ids), 1).long()
    all_q_ids = torch.tensor([f.q_ids for f in features], dtype=torch.long)

    dataset = TensorDataset(all_c_ids, all_c_lens, all_q_ids)
    # Sequential order keeps the loader aligned with the input file order.
    return DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=1)

@staticmethod
def sort_hypotheses(hypotheses):
    """Return a new list of hypotheses ordered best-first.

    Ordering is by each hypothesis's ``avg_log_prob`` attribute, highest
    value first; the input iterable is left untouched.
    """
    def score(hyp):
        return hyp.avg_log_prob

    ranked = list(hypotheses)
    ranked.sort(key=score, reverse=True)
    return ranked
Expand All @@ -60,14 +81,9 @@ def decode(self):
pred_fw = open(self.pred_dir, "w")
golden_fw = open(self.golden_dir, "w")
for i, eval_data in enumerate(self.data_loader):
if config.use_tag:
src_seq, ext_src_seq, src_len, trg_seq, \
ext_trg_seq, trg_len, tag_seq, oov_lst = eval_data
else:
src_seq, ext_src_seq, src_len, \
trg_seq, ext_trg_seq, trg_len, oov_lst = eval_data
tag_seq = None
best_question = self.beam_search(src_seq, ext_src_seq, src_len, tag_seq)
c_ids, c_lens, q_ids = eval_data
tag_seq = None
best_question = self.beam_search(c_ids, c_lens, q_ids, tag_seq)
# discard START token
output_indices = [int(idx) for idx in best_question.tokens[1:-1]]
decoded_words = outputids2words(output_indices, self.idx2tok, oov_lst[0])
Expand All @@ -85,18 +101,11 @@ def decode(self):
pred_fw.close()
golden_fw.close()

def beam_search(self, src_seq, ext_src_seq, src_len, tag_seq):
zeros = torch.zeros_like(src_seq)
enc_mask = torch.ByteTensor(src_seq == zeros)
src_len = torch.LongTensor(src_len)
prev_context = torch.zeros(1, 1, 2 * config.hidden_size)
def beam_search(self, src_seq, src_len, trg_seq, tag_seq):

if config.use_gpu:
src_seq = src_seq.to(config.device)
ext_src_seq = ext_src_seq.to(config.device)
_seq = src_seq.to(config.device)
src_len = src_len.to(config.device)
enc_mask = enc_mask.to(config.device)
prev_context = prev_context.to(config.device)

if config.use_tag:
tag_seq = tag_seq.to(config.device)
Expand All @@ -106,12 +115,14 @@ def beam_search(self, src_seq, ext_src_seq, src_len, tag_seq):
hypotheses = [Hypothesis(tokens=[self.tok2idx[START_TOKEN]],
log_probs=[0.0],
state=(h[:, 0, :], c[:, 0, :]),
context=prev_context[0]) for _ in range(config.beam_size)]
context=None) for _ in range(config.beam_size)]
# tile enc_outputs, enc_mask for beam search
ext_src_seq = ext_src_seq.repeat(config.beam_size, 1)
ext_src_seq = src_seq.repeat(config.beam_size, 1)
enc_outputs = enc_outputs.repeat(config.beam_size, 1, 1)
zeros = enc_outputs.sum(dim=-1)
enc_mask = (zeros == 0).byte()
enc_features = self.model.decoder.get_encoder_features(enc_outputs)
enc_mask = enc_mask.repeat(config.beam_size, 1)

num_steps = 0
results = []
while num_steps < config.max_decode_step and len(results) < config.beam_size:
Expand All @@ -125,21 +136,20 @@ def beam_search(self, src_seq, ext_src_seq, src_len, tag_seq):
# build a batch whose size equals the beam size
all_state_h = []
all_state_c = []
all_context = []
for h in hypotheses:
state_h, state_c = h.state # [num_layers, d]
all_state_h.append(state_h)
all_state_c.append(state_c)
all_context.append(h.context)

prev_h = torch.stack(all_state_h, dim=1) # [num_layers, beam, d]
prev_c = torch.stack(all_state_c, dim=1) # [num_layers, beam, d]
prev_context = torch.stack(all_context, dim=0)
prev_states = (prev_h, prev_c)
# [beam_size, |V|]
logits, states, context_vector = self.model.decoder.decode(prev_y, ext_src_seq,
prev_states, prev_context,
enc_features, enc_mask)
logits, states, = self.model.decoder.decode(prev_y,
ext_src_seq,
prev_states,
enc_features,
enc_mask)
h_state, c_state = states
log_probs = F.log_softmax(logits, dim=1)
top_k_log_probs, top_k_ids \
Expand All @@ -150,12 +160,11 @@ def beam_search(self, src_seq, ext_src_seq, src_len, tag_seq):
for i in range(num_orig_hypotheses):
h = hypotheses[i]
state_i = (h_state[:, i, :], c_state[:, i, :])
context_i = context_vector[i]
for j in range(config.beam_size * 2):
new_h = h.extend(token=top_k_ids[i][j].item(),
log_prob=top_k_log_probs[i][j].item(),
state=state_i,
context=context_i)
context=None)
all_hypotheses.append(new_h)

hypotheses = []
Expand Down
Loading

0 comments on commit 63a0075

Please sign in to comment.