From 63a007596468a4a3ec3e61195330b69a29a7224c Mon Sep 17 00:00:00 2001 From: seanie12 Date: Mon, 29 Apr 2019 15:16:21 +0900 Subject: [PATCH] add files --- config.py | 18 ++- data_utils.py | 27 +++- infenrence.py | 65 +++++---- model.py | 231 ++++++++++++++++++++++++----- trainer.py | 391 ++++++++++++++++++++++++++++++++++++++++++++++++-- 5 files changed, 655 insertions(+), 77 deletions(-) diff --git a/config.py b/config.py index 74150fc..a5143ca 100644 --- a/config.py +++ b/config.py @@ -11,8 +11,8 @@ embedding = "./data/embedding.pkl" word2idx_file = "./data/word2idx.pkl" -model_path = "./save/seq2seq/train_422203728/20_2.68" -train = False +model_path = None +train = True device = "cuda:1" use_gpu = True debug = False @@ -20,15 +20,25 @@ freeze_embedding = True num_epochs = 20 -max_len = 400 +max_length = 400 +max_seq_len = 364 +max_query_len = 64 num_layers = 2 hidden_size = 300 embedding_size = 300 -lr = 0.1 + +# QA config +qa_lr = 5e-5 +gradient_accumulation_steps = 1 +warmup_proportion = 0.1 +dual_lambda = 0.1 + +lr = 1e-3 batch_size = 64 dropout = 0.3 max_grad_norm = 5.0 + use_tag = True use_pointer = True beam_size = 10 diff --git a/data_utils.py b/data_utils.py index 20a6d51..2e28b53 100644 --- a/data_utils.py +++ b/data_utils.py @@ -249,14 +249,14 @@ def merge(sequences): def get_loader(src_file, trg_file, word2idx, batch_size, use_tag=False, debug=False, shuffle=False): if use_tag: - dataset = SQuadDatasetWithTag(src_file, trg_file, config.max_len, + dataset = SQuadDatasetWithTag(src_file, trg_file, config.max_seq_len, word2idx, debug) dataloader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn_tag) else: - dataset = SQuadDataset(src_file, trg_file, config.max_len, + dataset = SQuadDataset(src_file, trg_file, config.max_seq_len, word2idx, debug) dataloader = data.DataLoader(dataset=dataset, batch_size=batch_size, @@ -565,3 +565,26 @@ def make_conll_format(examples, src_file, trg_file): src_fw.close() trg_fw.close() + + +def split_dev(input_file, dev_file, test_file): + with open(input_file) as f: + input_file = json.load(f) + + input_data = input_file["data"] + + # split the original SQuAD dev set into new dev / test set + num_total = len(input_data) + num_dev = int(num_total * 0.5) + + dev_data = input_data[:num_dev] + test_data = input_data[num_dev:] + + dev_dict = {"data": dev_data} + test_dict = {"data": test_data} + + with open(dev_file, "w") as f: + json.dump(dev_dict, f) + + with open(test_file, "w") as f: + json.dump(test_dict, f) diff --git a/infenrence.py b/infenrence.py index 2bff9bc..e489130 100644 --- a/infenrence.py +++ b/infenrence.py @@ -1,7 +1,9 @@ from model import Seq2seq import os from data_utils import START_TOKEN, END_ID, get_loader, UNK_ID, outputids2words +from squad_utils import read_squad_examples, convert_examples_to_features import torch +from torch.utils.data import SequentialSampler, DataLoader, TensorDataset import torch.nn.functional as F import config import pickle @@ -52,6 +54,25 @@ def __init__(self, model_path, output_dir): if not os.path.exists(output_dir): os.makedirs(output_dir) + def get_data_loader(self, file): + train_examples = read_squad_examples(file, is_training=True, debug=config.debug) + train_features = convert_examples_to_features(train_examples, + tokenizer=self.tokenizer, + max_seq_length=config.max_seq_len, + max_query_length=config.max_query_len, + doc_stride=128, + is_training=True) + + all_c_ids = torch.tensor([f.c_ids for f in train_features], dtype=torch.long) + all_c_lens = 
torch.sign(torch.sum(all_c_ids, 1)).long() + all_q_ids = torch.tensor([f.q_ids for f in train_features], dtype=torch.long) + + train_data = TensorDataset(all_c_ids, all_c_lens, all_q_ids) + sampler = SequentialSampler(train_data) + train_loader = DataLoader(train_data, sampler=sampler, batch_size=1) + + return train_loader + @staticmethod def sort_hypotheses(hypotheses): return sorted(hypotheses, key=lambda h: h.avg_log_prob, reverse=True) @@ -60,14 +81,9 @@ def decode(self): pred_fw = open(self.pred_dir, "w") golden_fw = open(self.golden_dir, "w") for i, eval_data in enumerate(self.data_loader): - if config.use_tag: - src_seq, ext_src_seq, src_len, trg_seq, \ - ext_trg_seq, trg_len, tag_seq, oov_lst = eval_data - else: - src_seq, ext_src_seq, src_len, \ - trg_seq, ext_trg_seq, trg_len, oov_lst = eval_data - tag_seq = None - best_question = self.beam_search(src_seq, ext_src_seq, src_len, tag_seq) + c_ids, c_lens, q_ids = eval_data + tag_seq = None + best_question = self.beam_search(c_ids, c_lens, q_ids, tag_seq) # discard START token output_indices = [int(idx) for idx in best_question.tokens[1:-1]] decoded_words = outputids2words(output_indices, self.idx2tok, oov_lst[0]) @@ -85,18 +101,11 @@ def decode(self): pred_fw.close() golden_fw.close() - def beam_search(self, src_seq, ext_src_seq, src_len, tag_seq): - zeros = torch.zeros_like(src_seq) - enc_mask = torch.ByteTensor(src_seq == zeros) - src_len = torch.LongTensor(src_len) - prev_context = torch.zeros(1, 1, 2 * config.hidden_size) + def beam_search(self, src_seq, src_len, trg_seq, tag_seq): if config.use_gpu: - src_seq = src_seq.to(config.device) - ext_src_seq = ext_src_seq.to(config.device) + _seq = src_seq.to(config.device) src_len = src_len.to(config.device) - enc_mask = enc_mask.to(config.device) - prev_context = prev_context.to(config.device) if config.use_tag: tag_seq = tag_seq.to(config.device) @@ -106,12 +115,14 @@ def beam_search(self, src_seq, ext_src_seq, src_len, tag_seq): hypotheses = [Hypothesis(tokens=[self.tok2idx[START_TOKEN]], log_probs=[0.0], state=(h[:, 0, :], c[:, 0, :]), - context=prev_context[0]) for _ in range(config.beam_size)] + context=None) for _ in range(config.beam_size)] # tile enc_outputs, enc_mask for beam search - ext_src_seq = ext_src_seq.repeat(config.beam_size, 1) + ext_src_seq = src_seq.repeat(config.beam_size, 1) enc_outputs = enc_outputs.repeat(config.beam_size, 1, 1) + zeros = enc_outputs.sum(dim=-1) + enc_mask = (zeros == 0).byte() enc_features = self.model.decoder.get_encoder_features(enc_outputs) - enc_mask = enc_mask.repeat(config.beam_size, 1) + num_steps = 0 results = [] while num_steps < config.max_decode_step and len(results) < config.beam_size: @@ -125,21 +136,20 @@ def beam_search(self, src_seq, ext_src_seq, src_len, tag_seq): # make batch of which size is beam size all_state_h = [] all_state_c = [] - all_context = [] for h in hypotheses: state_h, state_c = h.state # [num_layers, d] all_state_h.append(state_h) all_state_c.append(state_c) - all_context.append(h.context) prev_h = torch.stack(all_state_h, dim=1) # [num_layers, beam, d] prev_c = torch.stack(all_state_c, dim=1) # [num_layers, beam, d] - prev_context = torch.stack(all_context, dim=0) prev_states = (prev_h, prev_c) # [beam_size, |V|] - logits, states, context_vector = self.model.decoder.decode(prev_y, ext_src_seq, - prev_states, prev_context, - enc_features, enc_mask) + logits, states, = self.model.decoder.decode(prev_y, + ext_src_seq, + prev_states, + enc_features, + enc_mask) h_state, c_state = states log_probs = 
F.log_softmax(logits, dim=1) top_k_log_probs, top_k_ids \ @@ -150,12 +160,11 @@ def beam_search(self, src_seq, ext_src_seq, src_len, tag_seq): for i in range(num_orig_hypotheses): h = hypotheses[i] state_i = (h_state[:, i, :], c_state[:, i, :]) - context_i = context_vector[i] for j in range(config.beam_size * 2): new_h = h.extend(token=top_k_ids[i][j].item(), log_prob=top_k_log_probs[i][j].item(), state=state_i, - context=context_i) + context=None) all_hypotheses.append(new_h) hypotheses = [] diff --git a/model.py b/model.py index 690e6b6..cee82b2 100644 --- a/model.py +++ b/model.py @@ -5,31 +5,39 @@ from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence from torch_scatter import scatter_max from data_utils import UNK_ID +from pytorch_pretrained_bert import BertForQuestionAnswering INF = 1e12 class Encoder(nn.Module): - def __init__(self, embeddings, vocab_size, embedding_size, hidden_size, num_layers, dropout): + def __init__(self, embeddings, hidden_size, num_layers, dropout, use_tag): super(Encoder, self).__init__() + vocab_size, embedding_size = embeddings.size() + self.use_tag = use_tag self.embedding = nn.Embedding(vocab_size, embedding_size) - if config.use_tag: + if use_tag: self.tag_embedding = nn.Embedding(3, 3) lstm_input_size = embedding_size + 3 else: lstm_input_size = embedding_size + self.embedding = nn.Embedding(vocab_size, embedding_size) + if embeddings is not None: - self.embedding = nn.Embedding(vocab_size, embedding_size). \ - from_pretrained(embeddings, freeze=config.freeze_embedding) + if "FloatTensor" in embeddings.type(): + self.embedding.from_pretrained(embeddings, freeze=True) + else: + self.embedding.weight = embeddings + self.embedding.requires_grad = False self.num_layers = num_layers if self.num_layers == 1: dropout = 0.0 self.lstm = nn.LSTM(lstm_input_size, hidden_size, dropout=dropout, num_layers=num_layers, bidirectional=True, batch_first=True) - self.linear_trans = nn.Linear(2 * hidden_size, 2 * hidden_size) + self.linear_trans = nn.Linear(2 * hidden_size, 2 * hidden_size, bias=False) self.update_layer = nn.Linear(4 * hidden_size, 2 * hidden_size, bias=False) self.gate = nn.Linear(4 * hidden_size, 2 * hidden_size, bias=False) @@ -49,17 +57,21 @@ def gated_self_attn(self, queries, memories, mask): return updated_output def forward(self, src_seq, src_len, tag_seq): + total_length = src_seq.size(1) embedded = self.embedding(src_seq) - if config.use_tag and tag_seq is not None: + if self.use_tag and tag_seq is not None: tag_embedded = self.tag_embedding(tag_seq) embedded = torch.cat((embedded, tag_embedded), dim=2) packed = pack_padded_sequence(embedded, src_len, batch_first=True) + self.lstm.flatten_parameters() outputs, states = self.lstm(packed) # states : tuple of [4, b, d] - outputs, _ = pad_packed_sequence(outputs, batch_first=True) # [b, t, d] + outputs, _ = pad_packed_sequence(outputs, batch_first=True, + total_length=total_length) # [b, t, d] h, c = states # self attention - mask = (src_seq == 0).byte() + zeros = outputs.sum(dim=-1) + mask = (zeros == 0).byte() memories = self.linear_trans(outputs) outputs = self.gated_self_attn(outputs, memories, mask) @@ -75,21 +87,27 @@ def forward(self, src_seq, src_len, tag_seq): class Decoder(nn.Module): - def __init__(self, embeddings, vocab_size, embedding_size, hidden_size, num_layers, dropout): + def __init__(self, embeddings, hidden_size, num_layers, dropout): super(Decoder, self).__init__() + vocab_size, embedding_size = embeddings.size() self.vocab_size = vocab_size + self.embedding = 
nn.Embedding(vocab_size, embedding_size) if embeddings is not None: - self.embedding = nn.Embedding(vocab_size, embedding_size). \ - from_pretrained(embeddings, freeze=config.freeze_embedding) + self.embedding = nn.Embedding(vocab_size, embedding_size) + + if "FloatTensor" in embeddings.type(): + self.embedding.from_pretrained(embeddings, freeze=True) + else: + self.embedding.weight = embeddings + self.embedding.requires_grad = False if num_layers == 1: dropout = 0.0 self.encoder_trans = nn.Linear(hidden_size, hidden_size) - self.reduce_layer = nn.Linear(embedding_size + hidden_size, embedding_size) self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True, num_layers=num_layers, bidirectional=False, dropout=dropout) - self.concat_layer = nn.Linear(2 * hidden_size, hidden_size) + self.concat_layer = nn.Linear(2 * hidden_size, hidden_size, bias=False) self.logit_layer = nn.Linear(hidden_size, vocab_size) @staticmethod @@ -105,23 +123,23 @@ def attention(query, memories, mask): def get_encoder_features(self, encoder_outputs): return self.encoder_trans(encoder_outputs) - def forward(self, trg_seq, ext_src_seq, init_states, encoder_outputs, encoder_mask): + def forward(self, trg_seq, ext_src_seq, init_states, encoder_outputs): # trg_seq : [b,t] # init_states : [2,b,d] # encoder_outputs : [b,t,d] # init_states : a tuple of [2, b, d] batch_size, max_len = trg_seq.size() - hidden_size = encoder_outputs.size(-1) + zeros = encoder_outputs.sum(dim=-1) + encoder_mask = (zeros == 0).byte() memories = self.get_encoder_features(encoder_outputs) logits = [] prev_states = init_states - prev_context = torch.zeros((batch_size, 1, hidden_size), device=config.device) for i in range(max_len): y_i = trg_seq[:, i].unsqueeze(1) # [b, 1] embedded = self.embedding(y_i) # [b, 1, d] - lstm_inputs = self.reduce_layer(torch.cat([embedded, prev_context], dim=2)) - output, states = self.lstm(lstm_inputs, prev_states) + self.lstm.flatten_parameters() + output, states = self.lstm(embedded, prev_states) # encoder-decoder attention context, energy = self.attention(output, memories, encoder_mask) concat_input = torch.cat((output, context), dim=2).squeeze(dim=1) @@ -142,19 +160,17 @@ def forward(self, trg_seq, ext_src_seq, init_states, encoder_outputs, encoder_ma logits.append(logit) # update prev state and context prev_states = states - prev_context = context logits = torch.stack(logits, dim=1) # [b, t, |V|] return logits - def decode(self, y, ext_x, prev_states, prev_context, encoder_features, encoder_mask): + def decode(self, y, ext_x, prev_states, encoder_features, encoder_mask): # forward one step lstm # y : [b] embedded = self.embedding(y.unsqueeze(1)) - lstm_inputs = self.reduce_layer(torch.cat([embedded, prev_context], dim=2)) - output, states = self.lstm(lstm_inputs, prev_states) + output, states = self.lstm(embedded, prev_states) context, energy = self.attention(output, encoder_features, encoder_mask) concat_input = torch.cat((output, context), dim=2).squeeze(dim=1) logit_input = torch.tanh(self.concat_layer(concat_input)) @@ -169,22 +185,91 @@ def decode(self, y, ext_x, prev_states, prev_context, encoder_features, encoder_ out, _ = scatter_max(energy, ext_x, out=out) out = out.masked_fill(out == -INF, 0) logit = extended_logit + out - logit = logit.masked_fill(logit == -INF, 0) + logit = logit.masked_fill(logit == 0, -INF) # forcing UNK prob 0 - logit[:, UNK_ID] = -INF + # logit[:, UNK_ID] = -INF + + return logit, states + + +class PointerDecoder(nn.Module): + def __init__(self, hidden_size, num_layers, 
dropout): + super(PointerDecoder, self).__init__() + + self.go_embedding = nn.Parameter(torch.randn(1, hidden_size)) + self.concat_layer = nn.Linear(2 * hidden_size, hidden_size, bias=False) + self.attention_layer = nn.Linear(hidden_size, 1, bias=False) + self.lstm = nn.LSTM(hidden_size, hidden_size, dropout=dropout, + batch_first=True, num_layers=num_layers) + + def forward(self, enc_outputs, init_states, start_positions): + """ + + :param enc_outputs: [b,t,d] hidden states of encoder + :param init_states: tuple of [2, b, d]last hidden and cell states of encoder + :param start_positions: [b] ground truth for start position + :return: list of start logits and end logits + """ + batch_size, nsteps, _ = enc_outputs.size() + enc_mask = (torch.sum(enc_outputs, dim=-1) == 0).byte() + # tile go embedding + inputs = self.go_embedding.unsqueeze(0).repeat([batch_size, 1, 1]) + states = init_states + logits = [] + for i in range(2): + self.lstm.flatten_parameters() + hidden, states = self.lstm(inputs, states) + logit = self.attention(enc_outputs, enc_mask, hidden) + logits.append(logit) + # teacher forcing + inputs = enc_outputs[torch.arange(batch_size), start_positions] + inputs = inputs.unsqueeze(dim=1) + return logits + + def attention(self, memories, mask, dec_hidden): + nsteps = memories.size(1) + tiled_hidden = dec_hidden.repeat([1, nsteps, 1]) + concat_input = torch.cat([memories, tiled_hidden], dim=2) + attn_features = torch.tanh(self.concat_layer(concat_input)) + energies = self.attention_layer(attn_features).squeeze(2) # [b,t,1] -> [b,t] + logit = energies.masked_fill(mask, -1e12) # [b, t] + return logit + + +class AnswerSelector(nn.Module): + def __init__(self, embedding=None, model_path=None): + super(AnswerSelector, self).__init__() + self.encoder = Encoder(embedding, + config.hidden_size, + config.num_layers, + config.dropout, + use_tag=False) + self.decoder = PointerDecoder(2 * config.hidden_size, + config.num_layers, + config.dropout) + if model_path is not None: + ckpt = torch.load(model_path) + self.encoder.load_state_dict(ckpt["encoder_state_dict"]) + self.decoder.load_state_dict(ckpt["decoder_state_dict"]) - return logit, states, context + def forward(self, src_seqs, src_len, start_positions): + tag_seq = None + enc_outputs, enc_states = self.encoder(src_seqs, src_len, tag_seq) + logits = self.decoder(enc_outputs, enc_states, start_positions) + + return logits class Seq2seq(nn.Module): - def __init__(self, embedding=None, is_eval=False, model_path=None): + def __init__(self, embedding=None, use_tag=False, model_path=None): super(Seq2seq, self).__init__() - encoder = Encoder(embedding, config.vocab_size, - config.embedding_size, config.hidden_size, + encoder = Encoder(embedding, + config.hidden_size, config.num_layers, - config.dropout) - decoder = Decoder(embedding, config.vocab_size, - config.embedding_size, 2 * config.hidden_size, + config.dropout, + use_tag) + decoder = Decoder(embedding, + 2 * config.hidden_size, config.num_layers, config.dropout) @@ -196,9 +281,6 @@ def __init__(self, embedding=None, is_eval=False, model_path=None): self.encoder = encoder self.decoder = decoder - if is_eval: - self.eval_mode() - if model_path is not None: ckpt = torch.load(model_path) self.encoder.load_state_dict(ckpt["encoder_state_dict"]) @@ -211,3 +293,84 @@ def eval_mode(self): def train_mode(self): self.encoder = self.encoder.train() self.decoder = self.decoder.train() + + +class DualNet(nn.Module): + def __init__(self, c2q_model_path, c2a_model_path): + super(DualNet, self).__init__() 
+ + self.qa_model = BertForQuestionAnswering.from_pretrained("bert-base-uncased") + embedding = self.qa_model.bert.embeddings.word_embeddings.weight + + self.ca2q_model = Seq2seq(embedding, use_tag=True) + self.c_encoder = Encoder(embedding, config.hidden_size, + config.num_layers, config.dropout, + use_tag=False) + self.c2q_decoder = Decoder(embedding, 2 * config.hidden_size, + config.num_layers, config.dropout) + self.c2a_decoder = PointerDecoder(2 * config.hidden_size, + config.num_layers, + config.dropout) + + def forward(self, batch_data): + # sorting for using packed_sequence and padded_pack_sequence + c_ids, c_lens, tag_ids, q_ids, \ + input_ids, input_mask, segment_ids, \ + start_positions, end_positions, \ + noq_start_positions, noq_end_positions = batch_data + + # sorting for using packed_padded_sequence and pad_packed_sequence + c_lens, idx = torch.sort(c_lens, descending=True) + c_ids = c_ids[idx] + tag_ids = tag_ids[idx] + q_ids = q_ids[idx] + input_ids = input_ids[idx] + input_mask = input_mask[idx] + segment_ids = segment_ids[idx] + start_positions = start_positions[idx] + end_positions = end_positions[idx] + noq_start_positions = noq_start_positions[idx] + noq_end_positions = noq_end_positions[idx] + + # QA loss + qa_loss = self.qa_model(input_ids, segment_ids, input_mask, start_positions, end_positions) + + # QG without answer loss + enc_outputs, enc_states = self.c_encoder(c_ids, c_lens, None) + sos_q_ids = q_ids[:, :-1] + eos_q_ids = q_ids[:, 1:] + q_logits = self.c2q_decoder(sos_q_ids, c_ids, enc_states, enc_outputs) + batch_size, nsteps, _ = q_logits.size() + criterion = nn.CrossEntropyLoss(ignore_index=0) + preds = q_logits.view(batch_size * nsteps, -1) + targets = eos_q_ids.contiguous().view(-1) + c2q_loss = criterion(preds, targets) + + # QG with answer loss + enc_outputs, enc_states = self.ca2q_model.encoder(c_ids, c_lens, tag_ids) + q_logits = self.ca2q_model.decoder(sos_q_ids, c_ids, enc_states, enc_outputs) + preds = q_logits.view(batch_size * nsteps, -1) + ca2q_loss = criterion(preds, targets) + + # answer span without question + enc_outputs, states = self.c_encoder(c_ids, c_lens, None) + logits = self.c2a_decoder(enc_outputs, states, start_positions) + start_logits, end_logits = logits + + ignored_index = start_logits.size(1) + start_logits.clamp_(0, ignored_index) + end_logits.clamp_(0, ignored_index) + criterion = nn.CrossEntropyLoss(ignore_index=ignored_index) + + start_loss = criterion(start_logits, noq_start_positions) + end_loss = criterion(end_logits, noq_end_positions) + c2a_loss = (start_loss + end_loss) / 2 + + # regularization loss + reg_loss = (qa_loss + c2q_loss - ca2q_loss - c2a_loss) ** 2 + + qa_loss = (qa_loss + config.dual_lambda * reg_loss) + c2q_loss = (c2q_loss + config.dual_lambda * reg_loss) + ca2q_loss = (ca2q_loss + config.dual_lambda * reg_loss) + c2a_loss = (c2a_loss + config.dual_lambda * reg_loss) + return qa_loss, c2q_loss, ca2q_loss, c2a_loss diff --git a/trainer.py b/trainer.py index d0750fd..2dc107e 100644 --- a/trainer.py +++ b/trainer.py @@ -6,10 +6,13 @@ import torch import torch.nn as nn import torch.optim as optim +from torch.utils.data import TensorDataset, RandomSampler, DataLoader +from pytorch_pretrained_bert import BertTokenizer, BertAdam, BertForQuestionAnswering import config from data_utils import get_loader, eta, user_friendly_time, progress_bar, time_since -from model import Seq2seq +from model import Seq2seq, DualNet, AnswerSelector +from squad_utils import convert_examples_to_features, read_squad_examples class 
Trainer(object): @@ -41,13 +44,12 @@ def __init__(self, model_path=None): if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) - self.model = Seq2seq(embedding, model_path=model_path) + self.model = Seq2seq(embedding, config.use_tag, model_path=model_path) params = list(self.model.encoder.parameters()) \ + list(self.model.decoder.parameters()) self.lr = config.lr self.optim = optim.SGD(params, self.lr, momentum=0.8) - # self.optim = optim.Adam(params) self.criterion = nn.CrossEntropyLoss(ignore_index=0) def save_model(self, loss, epoch): @@ -82,8 +84,7 @@ def train(self): self.optim.zero_grad() batch_loss.backward() # gradient clipping - nn.utils.clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm) - nn.utils.clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm) + nn.utils.clip_grad_norm_(self.model.parameters(), config.max_grad_norm) self.optim.step() batch_loss = batch_loss.detach().item() msg = "{}/{} {} - ETA : {} - loss : {:.4f}" \ @@ -106,8 +107,6 @@ def step(self, train_data): src_seq, ext_src_seq, src_len, trg_seq, ext_trg_seq, trg_len, _ = train_data tag_seq = None src_len = torch.LongTensor(src_len) - enc_zeros = torch.zeros_like(src_seq) - enc_mask = torch.ByteTensor(src_seq == enc_zeros) if config.use_gpu: src_seq = src_seq.to(config.device) @@ -115,7 +114,6 @@ def step(self, train_data): src_len = src_len.to(config.device) trg_seq = trg_seq.to(config.device) ext_trg_seq = ext_trg_seq.to(config.device) - enc_mask = enc_mask.to(config.device) if config.use_tag: tag_seq = tag_seq.to(config.device) else: @@ -127,7 +125,7 @@ def step(self, train_data): if config.use_pointer: eos_trg = ext_trg_seq[:, 1:] - logits = self.model.decoder(sos_trg, ext_src_seq, enc_states, enc_outputs, enc_mask) + logits = self.model.decoder(sos_trg, ext_src_seq, enc_states, enc_outputs) batch_size, nsteps, _ = logits.size() preds = logits.view(batch_size * nsteps, -1) targets = eos_trg.contiguous().view(-1) @@ -149,3 +147,378 @@ def evaluate(self, msg): val_loss = np.mean(val_losses) return val_loss + + +class QGTrainer(object): + def __init__(self): + # load Bert Tokenizer and pre-trained word embedding + self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + embedding = BertForQuestionAnswering.from_pretrained("bert-base-uncased") \ + .bert.embeddings.word_embeddings.weight + self.model = Seq2seq(embedding, use_tag=False) + + train_dir = os.path.join("./save", "c2q") + + self.train_loader = self.get_data_loader("./squad/train-v1.1.json") + self.dev_loader = self.get_data_loader("./squad/new_dev-v1.1.json") + + self.model_dir = os.path.join(train_dir, "train_%d" % int(time.strftime("%m%d%H%M%S"))) + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + + params = list(self.model.encoder.parameters()) \ + + list(self.model.decoder.parameters()) + + self.lr = 1.0 + self.optim = optim.Adam(params) + self.criterion = nn.CrossEntropyLoss(ignore_index=0) + + def get_data_loader(self, file): + train_examples = read_squad_examples(file, is_training=True, debug=config.debug) + train_features = convert_examples_to_features(train_examples, + tokenizer=self.tokenizer, + max_seq_length=config.max_seq_len, + max_query_length=config.max_query_len, + doc_stride=128, + is_training=True) + + all_c_ids = torch.tensor([f.c_ids for f in train_features], dtype=torch.long) + all_c_lens = torch.sign(torch.sum(all_c_ids, 1)).long() + all_q_ids = torch.tensor([f.q_ids for f in train_features], dtype=torch.long) + all_tag_ids = 
torch.tensor([f.tag_ids for f in train_features], dtype=torch.long) + train_data = TensorDataset(all_c_ids, all_c_lens, all_tag_ids, all_q_ids) + sampler = RandomSampler(train_data) + train_loader = DataLoader(train_data, sampler=sampler, batch_size=config.batch_size) + + return train_loader + + def save_model(self, loss, epoch): + state_dict = { + "epoch": epoch, + "current_loss": loss, + "encoder_state_dict": self.model.encoder.state_dict(), + "decoder_state_dict": self.model.decoder.state_dict() + } + loss = round(loss, 2) + model_save_path = os.path.join(self.model_dir, str(epoch) + "_" + str(loss)) + torch.save(state_dict, model_save_path) + + def train(self): + batch_num = len(self.train_loader) + self.model.train_mode() + best_loss = 1e10 + for epoch in range(1, config.num_epochs + 1): + print("epoch {}/{} :".format(epoch, config.num_epochs), end="\r") + start = time.time() + + for batch_idx, train_data in enumerate(self.train_loader, start=1): + batch_loss = self.step(train_data) + + self.optim.zero_grad() + batch_loss.backward() + # gradient clipping + nn.utils.clip_grad_norm_(self.model.parameters(), config.max_grad_norm) + self.optim.step() + batch_loss = batch_loss.detach().item() + msg = "{}/{} {} - ETA : {} - loss : {:.4f}" \ + .format(batch_idx, batch_num, progress_bar(batch_idx, batch_num), + eta(start, batch_idx, batch_num), batch_loss) + print(msg, end="\r") + + # compute validation loss for every epoch + val_loss = self.evaluate(msg) + if val_loss <= best_loss: + best_loss = val_loss + self.save_model(val_loss, epoch) + + print("Epoch {} took {} - final loss : {:.4f} - val loss :{:.4f}" + .format(epoch, user_friendly_time(time_since(start)), batch_loss, val_loss)) + + def step(self, train_data): + c_ids, c_lens, tag_ids, q_ids = train_data + # exclude unnecessary PAD tokens of c_ids and q_ids + max_c_len = torch.max(c_lens) + c_ids = c_ids[:, :max_c_len] + q_len = torch.sum(torch.sign(q_ids), 1).long() + max_q_len = torch.max(q_len) + q_ids = q_ids[:, :max_q_len] + + # sort data by the length of input seq and allocate tensors to gpu device + c_lens, idx = torch.sort(c_lens) + c_ids = c_ids[idx].to(config.device) + c_lens = c_lens.to(config.device) + q_ids = q_ids[idx].to(config.device) + + if config.use_tag: + tag_ids = tag_ids[idx].to(config.device) # we do not use tag seqs + else: + tag_ids = None + # forward Encoder + enc_outputs, enc_states = self.model.encoder(c_ids, c_lens, tag_ids) + + sos_trg = q_ids[:, :-1] # exclude END token + eos_trg = q_ids[:, 1:] # exclude START token + # forward decoder + logits = self.model.decoder(sos_trg, c_ids, enc_states, enc_outputs) + # compute loss + batch_size, nsteps, _ = logits.size() + preds = logits.view(batch_size * nsteps, -1) + targets = eos_trg.contiguous().view(-1) + loss = self.criterion(preds, targets) + return loss + + def evaluate(self, msg): + self.model.eval_mode() + num_val_batches = len(self.dev_loader) + val_losses = [] + for i, val_data in enumerate(self.dev_loader, start=1): + with torch.no_grad(): + val_batch_loss = self.step(val_data) + val_losses.append(val_batch_loss.item()) + msg2 = "{} => Evaluating :{}/{}".format(msg, i, num_val_batches) + print(msg2, end="\r") + # go back to train mode + self.model.train_mode() + val_loss = np.mean(val_losses) + + return val_loss + + +class C2ATrainer(object): + def __init__(self): + self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + embedding = BertForQuestionAnswering.from_pretrained("bert-base-uncased") \ + .bert.embeddings.word_embeddings.weight + 
+ # instantiate class and allocate it to gpu device + self.model = AnswerSelector(embedding).to(config.device) + train_dir = os.path.join("./save", "c2a") + self.train_loader = self.get_data_loader("./squad/train-v1.1.json") + self.dev_loader = self.get_data_loader("./squad/new_dev-v1.1.json") + self.model_dir = os.path.join(train_dir, "train_%d" % int(time.strftime("%m%d%H%M%S"))) + if not os.path.exists(self.model_dir): + os.makedirs(self.model_dir) + + params = self.model.parameters() + self.optim = optim.Adam(params) + + def get_data_loader(self, file): + train_examples = read_squad_examples(file, is_training=True, debug=config.debug) + train_features = convert_examples_to_features(train_examples, + tokenizer=self.tokenizer, + max_seq_length=config.max_seq_len, + max_query_length=config.max_query_len, + doc_stride=128, + is_training=True) + + all_c_ids = torch.tensor([f.c_ids for f in train_features], dtype=torch.long) + all_c_lens = torch.sum(torch.sign(all_c_ids), 1).long() + all_noq_start_positions = torch.tensor([f.noq_start_position for f in train_features], dtype=torch.long) + all_noq_end_positions = torch.tensor([f.noq_end_position for f in train_features], dtype=torch.long) + + train_data = TensorDataset(all_c_ids, all_c_lens, all_noq_start_positions, all_noq_end_positions) + + train_loader = DataLoader(train_data, shuffle=True, batch_size=config.batch_size) + + return train_loader + + def save_model(self, loss, epoch): + loss = round(loss, 2) + model_save_path = os.path.join(self.model_dir, str(epoch) + "_" + str(loss)) + state_dict = self.model.state_dict() + torch.save(state_dict, model_save_path) + + def train(self): + batch_num = len(self.train_loader) + best_loss = 1e10 + for epoch in range(1, config.num_epochs + 1): + print("epoch {}/{} :".format(epoch, config.num_epochs), end="\r") + start = time.time() + + for batch_idx, train_data in enumerate(self.train_loader, start=1): + batch_loss = self.step(train_data) + + self.optim.zero_grad() + batch_loss.backward() + # gradient clipping + nn.utils.clip_grad_norm_(self.model.parameters(), config.max_grad_norm) + self.optim.step() + batch_loss = batch_loss.detach().item() + msg = "{}/{} {} - ETA : {} - loss : {:.4f}" \ + .format(batch_idx, batch_num, progress_bar(batch_idx, batch_num), + eta(start, batch_idx, batch_num), batch_loss) + print(msg, end="\r") + + # compute validation loss for every epoch + val_loss = self.evaluate(msg) + if val_loss <= best_loss: + best_loss = val_loss + self.save_model(val_loss, epoch) + + print("Epoch {} took {} - final loss : {:.4f} - val loss :{:.4f}" + .format(epoch, user_friendly_time(time_since(start)), batch_loss, val_loss)) + + def step(self, train_data): + c_ids, c_lens, start_positions, end_positions = train_data + + # sort data allocate tensors to gpu device + c_lens, idx = torch.sort(c_lens, descending=True) + c_ids = c_ids[idx].to(config.device) + c_lens = c_lens.to(config.device) + start_positions = start_positions[idx].to(config.device) + end_positions = end_positions[idx].to(config.device) + # forward pass + start_logits, end_logits = self.model(c_ids, c_lens, start_positions) + # compute loss + + ignored_index = start_logits.size(1) + start_logits.clamp_(0, ignored_index) + end_logits.clamp_(0, ignored_index) + criterion = nn.CrossEntropyLoss(ignore_index=ignored_index) + + start_loss = criterion(start_logits, start_positions) + end_loss = criterion(end_logits, end_positions) + loss = (start_loss + end_loss) / 2 + return loss + + def evaluate(self, msg): + self.model.eval() + 
num_val_batches = len(self.dev_loader) + val_losses = [] + for i, val_data in enumerate(self.dev_loader, start=1): + with torch.no_grad(): + val_batch_loss = self.step(val_data) + val_losses.append(val_batch_loss.item()) + msg2 = "{} => Evaluating :{}/{}".format(msg, i, num_val_batches) + print(msg2, end="\r") + # go back to train mode + self.model.train() + val_loss = np.mean(val_losses) + + return val_loss + + +class DualTrainer(object): + def __init__(self, c2q_model_path, c2a_model_path): + self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + self.model = DualNet(c2q_model_path, c2a_model_path) + self.save_dir = os.path.join("./save", "dual") + self.num_train = None + # read data-set and prepare iterator + self.train_loader = self.get_data_loader("./squad/train-v1.1.json") + self.dev_loader = self.get_data_loader("./squad/new_dev-v1.1.json") + + num_train_optimization_steps = int( + self.num_train / config.batch_size / config.gradient_accumulation_steps) * config.num_epochs + # optimizer + param_optimizer = list(self.model.qa_model.named_parameters()) + # hack to remove pooler, which is not used + # thus it produce None grad that break apex + param_optimizer = [n for n in param_optimizer if "pooler" not in n[0]] + no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, + {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} + ] + self.qa_opt = BertAdam(optimizer_grouped_parameters, + lr=config.qa_lr, + warmup=config.warmup_proportion, + t_total=num_train_optimization_steps) + + params = list(self.model.c_encoder.parameters()) \ + + list(self.model.ca2q_model.parameters()) \ + + list(self.model.c2a_decoder.parameters()) \ + + list(self.model.c2q_decoder.parameters()) + self.general_opt = optim.Adam(params, lr=config.lr) + + # assign model to device and wrap it with Dataparallel + torch.cuda.set_device(0) + self.model.cuda() + self.model = nn.DataParallel(self.model) + + self.model.train() + + def get_data_loader(self, file): + train_examples = read_squad_examples(file, is_training=True, debug=config.debug) + train_features = convert_examples_to_features(train_examples, + tokenizer=self.tokenizer, + max_seq_length=config.max_seq_len, + max_query_length=config.max_query_len, + doc_stride=128, + is_training=True) + self.num_train = len(train_examples) + all_c_ids = torch.tensor([f.c_ids for f in train_features], dtype=torch.long) + all_c_lens = torch.sign(torch.sum(all_c_ids, 1)).long() + all_tag_ids = torch.tensor([f.tag_ids for f in train_features], dtype=torch.long) + all_q_ids = torch.tensor([f.q_ids for f in train_features], dtype=torch.long) + all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) + all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) + all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) + all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long) + all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long) + all_noq_start_positions = torch.tensor([f.noq_start_position for f in train_features], dtype=torch.long) + all_noq_end_positions = torch.tensor([f.noq_end_position for f in train_features], dtype=torch.long) + + train_data = TensorDataset(all_c_ids, all_c_lens, all_tag_ids, + 
all_q_ids, all_input_ids, all_input_mask, + all_segment_ids, all_start_positions, all_end_positions, + all_noq_start_positions, all_noq_end_positions) + + sampler = RandomSampler(train_data) + train_loader = DataLoader(train_data, sampler=sampler, batch_size=config.batch_size) + + return train_loader + + def save_model(self, accuracy, epoch): + acc = round(accuracy, 3) + dir_name = os.path.join(self.save_dir, "{:.3f}_{}".format(acc, epoch)) + # save bert model + model_to_save = self.model.module.qa_model if hasattr(self.model, "module") else self.model.qa_model + model_file = os.path.join(dir_name, "pytorch_model.bin") + config_file = os.path.join(dir_name, "bert_config.json") + + state_dict = model_to_save.state_dict() + torch.save(state_dict, model_file) + json_file = model_to_save.config.to_json_string(config_file) + with open(config_file, "w") as f: + f.write(json_file) + + def train(self): + global_step = 1 + device = torch.device("cuda") + batch_num = len(self.train_loader) + for epoch in range(1, config.num_epochs + 1): + start = time.time() + for batch_idx, batch in enumerate(self.train_loader, start=1): + batch = tuple(t.to(device) for t in batch) + qa_loss, c2q_loss, ca2q_loss, c2a_loss = self.model(batch) + # zero grad + self.qa_opt.zero_grad() + self.general_opt.zero_grad() + # mean() to average across multiple gpu and back-propagation + qa_loss = qa_loss.mean() + c2q_loss = c2q_loss.mean() + ca2q_loss = ca2q_loss.mean() + c2a_loss = c2a_loss.mean() + + qa_loss.backward(retain_graph=True) + c2q_loss.backward(retain_graph=True) + ca2q_loss.backward(retain_graph=True) + c2a_loss.backward() + + # clip gradient + nn.utils.clip_grad_norm_(self.model.module.c_encoder.parameters(), config.max_grad_norm) + nn.utils.clip_grad_norm_(self.model.module.c2q_decoder.parameters(), config.max_grad_norm) + nn.utils.clip_grad_norm_(self.model.module.c2a_decoder.parameters(), config.max_grad_norm) + nn.utils.clip_grad_norm_(self.model.module.ca2q_model.parameters(), config.max_grad_norm) + + # update params + self.qa_opt.step() + self.general_opt.step() + global_step += 1 + msg = "{}/{} {} - ETA : {} - qa_loss: {:.2f}, c2q_loss :{:.2f}, ca2q_loss :{:.2f}, c2a_loss:{:.2f}" \ + .format(batch_idx, batch_num, progress_bar(batch_idx, batch_num), + eta(start, batch_idx, batch_num), + qa_loss.item(), c2q_loss.item(), ca2q_loss.item(), c2a_loss.item()) + print(msg, end="\r") + print("----------------------------")
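
Note on the dual objective implemented in DualNet.forward above: the four task losses are coupled through a single squared penalty weighted by config.dual_lambda. The sketch below isolates just that computation; the scalar loss values are dummies standing in for the BERT QA loss, the two question-generation losses, and the pointer-decoder answer loss, and the probabilistic-duality reading in the comments is an interpretation of the formula rather than something stated in the patch.

import torch

dual_lambda = 0.1  # config.dual_lambda in the patch

# Dummy scalar losses; in DualNet.forward these come from:
#   qa_loss   ~ -log P(a | c, q)   (BertForQuestionAnswering)
#   c2q_loss  ~ -log P(q | c)      (c_encoder + c2q_decoder)
#   ca2q_loss ~ -log P(q | c, a)   (ca2q_model with answer tags)
#   c2a_loss  ~ -log P(a | c)      (c_encoder + c2a_decoder)
qa_loss = torch.tensor(2.3)
c2q_loss = torch.tensor(4.1)
ca2q_loss = torch.tensor(3.0)
c2a_loss = torch.tensor(1.7)

# Duality penalty: the two factorizations of log P(q, a | c) should agree,
# i.e. log P(q | c) + log P(a | c, q) == log P(a | c) + log P(q | c, a),
# so the difference of the corresponding negative log-likelihoods is squared.
reg_loss = (qa_loss + c2q_loss - ca2q_loss - c2a_loss) ** 2

# Each task is then optimized on its own loss plus the shared penalty,
# exactly as at the end of DualNet.forward.
qa_loss = qa_loss + dual_lambda * reg_loss
c2q_loss = c2q_loss + dual_lambda * reg_loss
ca2q_loss = ca2q_loss + dual_lambda * reg_loss
c2a_loss = c2a_loss + dual_lambda * reg_loss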
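
The patch does not add an entry point for the new trainers. A plausible driver, shown only as an assumption (the file name main.py and the checkpoint paths are hypothetical), would pretrain the c2q and c2a models with QGTrainer and C2ATrainer and then start DualTrainer from their saved checkpoints:

# main.py (hypothetical, not part of the patch)
from trainer import QGTrainer, C2ATrainer, DualTrainer

if __name__ == "__main__":
    # 1) pretrain question generation without answer tags (checkpoints under ./save/c2q)
    QGTrainer().train()
    # 2) pretrain answer-span selection without questions (checkpoints under ./save/c2a)
    C2ATrainer().train()
    # 3) joint dual training; replace <ckpt> with the checkpoints saved above
    DualTrainer(c2q_model_path="./save/c2q/<ckpt>",
                c2a_model_path="./save/c2a/<ckpt>").train()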