diff --git a/nmt.py b/nmt.py new file mode 100644 index 0000000..b57efc0 --- /dev/null +++ b/nmt.py @@ -0,0 +1,1462 @@ +from __future__ import print_function + +import re + +import torch +import torch.nn as nn +import torch.nn.utils +from torch.autograd import Variable +from torch import optim +from torch.nn import Parameter +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence + +from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction +import time +import numpy as np +from collections import defaultdict, Counter, namedtuple +from itertools import chain, islice +import argparse, os, sys + +from util import read_corpus, data_iter, batch_slice, map_ent_label +from vocab import Vocab, VocabEntry +from process_samples import generate_hamming_distance_payoff_distribution +import math + +from rougescore import * +from nli_models.baseline_snli_one import * + +def init_config(): + parser = argparse.ArgumentParser() + parser.add_argument('--seed', default=5783287, type=int, help='random seed') + parser.add_argument('--cuda', action='store_true', default=False, help='use gpu') + parser.add_argument('--mode', choices=['train', 'raml_train', 'test', 'sample', 'prob', 'interactive'], + default='train', help='run mode') + parser.add_argument('--vocab', type=str, help='path of the serialized vocabulary') + parser.add_argument('--switch', default=10, type=int, help='multi-task switch') + parser.add_argument('--batch_size', default=64, type=int, help='batch size') + parser.add_argument('--beam_size', default=10, type=int, help='beam size for beam search') + parser.add_argument('--sample_size', default=5, type=int, help='sample size') + parser.add_argument('--embed_size', default=300, type=int, help='size of word embeddings') + parser.add_argument('--hidden_size', default=512, type=int, help='size of LSTM hidden states') + parser.add_argument('--dropout', default=0.5, type=float, help='dropout rate') + + parser.add_argument('--train_src', type=str, help='path to the training source file') + parser.add_argument('--train_tgt', type=str, help='path to the training target file') + parser.add_argument('--train_ent_x', type=str, help='path to the training entailment x file') + parser.add_argument('--train_ent_y', type=str, help='path to the training entailment y file') + parser.add_argument('--train_ent_label', type=str, help='path to the training entailment label file') + parser.add_argument('--dev_src', type=str, help='path to the dev source file') + parser.add_argument('--dev_tgt', type=str, help='path to the dev target file') + parser.add_argument('--test_src', type=str, help='path to the test source file') + parser.add_argument('--test_tgt', type=str, help='path to the test target file') + + parser.add_argument('--decode_max_time_step', default=20, type=int, help='maximum number of time steps used ' + 'in decoding and sampling') + + parser.add_argument('--valid_niter', default=2000, type=int, help='every n iterations to perform validation') + parser.add_argument('--valid_metric', default='rouge2_f', choices=['rouge2_f', 'bleu', 'ppl', 'word_acc', 'sent_acc'], help='metric used for validation') + parser.add_argument('--log_every', default=500, type=int, help='every n iterations to log training statistics') + parser.add_argument('--load_model', default=None, type=str, help='load a pre-trained model') + parser.add_argument('--save_to', default='model', type=str, help='save trained model to') + 
parser.add_argument('--save_model_after', default=0, help='save the model only after n validation iterations') + parser.add_argument('--save_to_file', default=None, type=str, help='if provided, save decoding results to file') + parser.add_argument('--save_nbest', default=False, action='store_true', help='save nbest decoding results') + parser.add_argument('--patience', default=30, type=int, help='training patience') + parser.add_argument('--drop_patience', default=12, type=int, help='training lr drop patience') + parser.add_argument('--uniform_init', default=0.1, type=float, help='if specified, use uniform initialization for all parameters') + parser.add_argument('--clip_grad', default=5., type=float, help='clip gradients') + parser.add_argument('--max_niter', default=-1, type=int, help='maximum number of training iterations') + parser.add_argument('--lr', default=0.0005, type=float, help='learning rate') + parser.add_argument('--lr_decay', default=0.5, type=float, help='decay learning rate if the validation performance drops') + + # raml training + parser.add_argument('--debug', default=False, action='store_true') + parser.add_argument('--temp', default=0.85, type=float, help='temperature in reward distribution') + parser.add_argument('--raml_sample_mode', default='pre_sample', + choices=['pre_sample', 'hamming_distance', 'hamming_distance_impt_sample'], + help='sample mode when using RAML') + parser.add_argument('--raml_sample_file', type=str, help='path to the sampled targets') + parser.add_argument('--raml_bias_groundtruth', action='store_true', default=False, help='make sure ground truth y* is in samples') + parser.add_argument('--reward_type', default='entailment', type=str, choices=['rouge2_f', 'bleu', 'entailment']) + + parser.add_argument('--smooth_bleu', action='store_true', default=False, + help='smooth sentence level BLEU score.') + + #TODO: greedy sampling is still buggy! 
+ parser.add_argument('--sample_method', default='random', choices=['random', 'greedy']) + + args = parser.parse_args() + + # seed the RNG + torch.manual_seed(args.seed) + if args.cuda: + torch.cuda.manual_seed(args.seed) + np.random.seed(args.seed * 13 / 7) + + return args + + +def input_transpose(sents, pad_token): + max_len = max(len(s) for s in sents) + batch_size = len(sents) + + sents_t = [] + masks = [] + for i in xrange(max_len): + sents_t.append([sents[k][i] if len(sents[k]) > i else pad_token for k in xrange(batch_size)]) + masks.append([1 if len(sents[k]) > i else 0 for k in xrange(batch_size)]) + + return sents_t, masks + + +def word2id(sents, vocab): + if type(sents[0]) == list: + return [[vocab[w] for w in s] for s in sents] + else: + return [vocab[w] for w in sents] + + +def tensor_transform(linear, X): + # X is a 3D tensor + return linear(X.contiguous().view(-1, X.size(2))).view(X.size(0), X.size(1), -1) + + +def sentence_entailment(src, targ, word_dic, input_encoder, inter_atten): + + if '' in src: + src.remove('') + if '' in targ: + targ.remove('') + src = [''] + src + targ = [''] + targ + src = [word_dic[w] for w in src if w in word_dic] + targ = [word_dic[w] for w in targ if w in word_dic] + source = Variable((torch.from_numpy(np.array(src, dtype=int)) - 1).cuda()) + target = Variable((torch.from_numpy(np.array(targ, dtype=int)) - 1).cuda()) + test_src_linear, test_tgt_linear=input_encoder(source, target) + log_prob=inter_atten(test_src_linear, test_tgt_linear) + scores = F.softmax(log_prob).data.cpu().numpy()[0] + reward = float(scores[0]) + + return reward + +class NMT(nn.Module): + def __init__(self, args, vocab): + super(NMT, self).__init__() + + self.args = args + + self.vocab = vocab + + self.src_embed = nn.Embedding(len(vocab.src), args.embed_size, padding_idx=vocab.src['']) + self.tgt_embed = nn.Embedding(len(vocab.tgt), args.embed_size, padding_idx=vocab.tgt['']) + + self.encoder_lstm = nn.LSTM(args.embed_size, args.hidden_size, bidirectional=True, dropout=args.dropout) + self.decoder_lstm = nn.LSTMCell(args.embed_size + args.hidden_size, args.hidden_size) + + # attention: dot product attention + # project source encoding to decoder rnn's h space + self.att_src_linear = nn.Linear(args.hidden_size * 2, args.hidden_size, bias=False) + + # transformation of decoder hidden states and context vectors before reading out target words + # this produces the `attentional vector` in (Luong et al., 2015) + self.att_vec_linear = nn.Linear(args.hidden_size * 2 + args.hidden_size, args.hidden_size, bias=False) + + # prediction layer of the target vocabulary + self.readout = nn.Linear(args.hidden_size, len(vocab.tgt), bias=False) + + # dropout layer + self.dropout = nn.Dropout(args.dropout) + self.decoder_cell_init = nn.Linear(args.hidden_size * 2, args.hidden_size) + self.decoder_hidden_init = nn.Linear(args.hidden_size * 2, args.hidden_size) + + self.classifier = nn.Sequential( + nn.Dropout(p=args.dropout), + nn.Linear(args.hidden_size * 8, 512), + nn.Tanh(), + nn.Dropout(p=args.dropout), + nn.Linear(512, 512), + nn.Tanh(), + nn.Dropout(p=args.dropout), + nn.Linear(512, 3), + ) + #self.sigmoid = nn.Sigmoid() + + def forward(self, task, src_sents, src_sents_len, tgt_words, tgt_sents_len): + if task == "summarization": + src_encodings, sv, init_ctx_vec = self.encode(src_sents, src_sents_len) + scores = self.decode(src_encodings, init_ctx_vec, tgt_words) + elif task == "entailment": + src_encodings, src_sentence_vectors, init_ctx_vec_src = self.encode(src_sents, src_sents_len) 
+ tgt_sentence_vectors = self.encode_sort(tgt_words, tgt_sents_len) + features = torch.cat((src_sentence_vectors, tgt_sentence_vectors, torch.abs(src_sentence_vectors - tgt_sentence_vectors), \ + src_sentence_vectors * tgt_sentence_vectors), 1) + scores = self.classifier(features) + + return scores + + def encode_sort(self, src_sents, src_sents_len): + """ + :param src_sents: (src_sent_len, batch_size), sorted by the length of the source + :param src_sents_len: (src_sent_len) + """ + src_sents_len_sort, idx_sort = np.sort(src_sents_len)[::-1], np.argsort(src_sents_len)[::-1] + idx_sort = list(idx_sort) + src_sents_sort = src_sents.index_select(1, Variable(torch.cuda.LongTensor(idx_sort))) + #src_sents_sort = torch.index_select(src_sents, 1, Variable(torch.cuda.LongTensor(idx_sort))) + + # (src_sent_len, batch_size, embed_size) + src_word_embed = self.src_embed(src_sents_sort) + packed_src_embed = pack_padded_sequence(src_word_embed, src_sents_len_sort) + + # output: (src_sent_len, batch_size, hidden_size) + output, (last_state, last_cell) = self.encoder_lstm(packed_src_embed) + + sentence_vector = torch.cat([last_state[0], last_state[1]], 1) + idx_unsort = np.argsort(idx_sort) + sentence_vector = sentence_vector.index_select(0, Variable(torch.cuda.LongTensor(idx_unsort))) + + return sentence_vector + + def encode(self, src_sents, src_sents_len): + """ + :param src_sents: (src_sent_len, batch_size), sorted by the length of the source + :param src_sents_len: (src_sent_len) + """ + # (src_sent_len, batch_size, embed_size) + src_word_embed = self.src_embed(src_sents) + packed_src_embed = pack_padded_sequence(src_word_embed, src_sents_len) + + # output: (src_sent_len, batch_size, hidden_size) + output, (last_state, last_cell) = self.encoder_lstm(packed_src_embed) + output, _ = pad_packed_sequence(output) + #dec_init_state = F.tanh(self.decoder_init(torch.cat([last_state[0], last_state[1]], 1))) + dec_init_cell = F.tanh(self.decoder_cell_init(torch.cat([last_cell[0], last_cell[1]], 1))) + dec_init_state = F.tanh(self.decoder_hidden_init(torch.cat([last_state[0], last_state[1]], 1))) + + sentence_vector = torch.cat([last_state[0], last_state[1]], 1) + + return output, sentence_vector, (dec_init_state, dec_init_cell) + + def decode(self, src_encoding, dec_init_vec, tgt_sents): + """ + :param src_encoding: (src_sent_len, batch_size, hidden_size) + :param dec_init_vec: (batch_size, hidden_size) + :param tgt_sents: (tgt_sent_len, batch_size) + :return: + """ + init_state = dec_init_vec[0] + init_cell = dec_init_vec[1] + hidden = (init_state, init_cell) + + new_tensor = init_cell.data.new + batch_size = src_encoding.size(1) + + # (batch_size, src_sent_len, hidden_size * 2) + src_encoding = src_encoding.permute(1, 0, 2) + # (batch_size, src_sent_len, hidden_size) + src_encoding_att_linear = tensor_transform(self.att_src_linear, src_encoding) + # initialize attentional vector + att_tm1 = Variable(new_tensor(batch_size, self.args.hidden_size).zero_(), requires_grad=False) + + tgt_word_embed = self.tgt_embed(tgt_sents) + scores = [] + + # start from ``, until y_{T-1} + for y_tm1_embed in tgt_word_embed.split(split_size=1): + # input feeding: concate y_tm1 and previous attentional vector + x = torch.cat([y_tm1_embed.squeeze(0), att_tm1], 1) + + # h_t: (batch_size, hidden_size) + h_t, cell_t = self.decoder_lstm(x, hidden) + h_t = self.dropout(h_t) + + ctx_t, alpha_t = self.dot_prod_attention(h_t, src_encoding, src_encoding_att_linear) + + att_t = F.tanh(self.att_vec_linear(torch.cat([h_t, ctx_t], 1))) # 
E.q. (5) + att_t = self.dropout(att_t) + + score_t = self.readout(att_t) # E.q. (6) + scores.append(score_t) + + att_tm1 = att_t + hidden = h_t, cell_t + + scores = torch.stack(scores) + return scores + + def translate(self, src_sents, beam_size=None, to_word=True): + """ + perform beam search + TODO: batched beam search + """ + if not type(src_sents[0]) == list: + src_sents = [src_sents] + if not beam_size: + beam_size = args.beam_size + + src_sents_var = to_input_variable(src_sents, self.vocab.src, cuda=args.cuda, is_test=True) + + src_encoding, sv, dec_init_vec = self.encode(src_sents_var, [len(src_sents[0])]) + src_encoding_att_linear = tensor_transform(self.att_src_linear, src_encoding) + + init_state = dec_init_vec[0] + init_cell = dec_init_vec[1] + hidden = (init_state, init_cell) + + att_tm1 = Variable(torch.zeros(1, self.args.hidden_size), volatile=True) + hyp_scores = Variable(torch.zeros(1), volatile=True) + if args.cuda: + att_tm1 = att_tm1.cuda() + hyp_scores = hyp_scores.cuda() + + eos_id = self.vocab.tgt[''] + bos_id = self.vocab.tgt[''] + tgt_vocab_size = len(self.vocab.tgt) + + hypotheses = [[bos_id]] + completed_hypotheses = [] + completed_hypothesis_scores = [] + + t = 0 + while len(completed_hypotheses) < beam_size and t < args.decode_max_time_step: + t += 1 + hyp_num = len(hypotheses) + + expanded_src_encoding = src_encoding.expand(src_encoding.size(0), hyp_num, src_encoding.size(2)) + expanded_src_encoding_att_linear = src_encoding_att_linear.expand(src_encoding_att_linear.size(0), hyp_num, src_encoding_att_linear.size(2)) + + y_tm1 = Variable(torch.LongTensor([hyp[-1] for hyp in hypotheses]), volatile=True) + if args.cuda: + y_tm1 = y_tm1.cuda() + + y_tm1_embed = self.tgt_embed(y_tm1) + + x = torch.cat([y_tm1_embed, att_tm1], 1) + + # h_t: (hyp_num, hidden_size) + h_t, cell_t = self.decoder_lstm(x, hidden) + h_t = self.dropout(h_t) + + ctx_t, alpha_t = self.dot_prod_attention(h_t, expanded_src_encoding.permute(1, 0, 2), expanded_src_encoding_att_linear.permute(1, 0, 2)) + + att_t = F.tanh(self.att_vec_linear(torch.cat([h_t, ctx_t], 1))) + att_t = self.dropout(att_t) + + score_t = self.readout(att_t) + p_t = F.log_softmax(score_t) + + live_hyp_num = beam_size - len(completed_hypotheses) + new_hyp_scores = (hyp_scores.unsqueeze(1).expand_as(p_t) + p_t).view(-1) + top_new_hyp_scores, top_new_hyp_pos = torch.topk(new_hyp_scores, k=live_hyp_num) + prev_hyp_ids = top_new_hyp_pos / tgt_vocab_size + word_ids = top_new_hyp_pos % tgt_vocab_size + # new_hyp_scores = new_hyp_scores[top_new_hyp_pos.data] + + new_hypotheses = [] + + live_hyp_ids = [] + new_hyp_scores = [] + for prev_hyp_id, word_id, new_hyp_score in zip(prev_hyp_ids.cpu().data, word_ids.cpu().data, top_new_hyp_scores.cpu().data): + hyp_tgt_words = hypotheses[prev_hyp_id] + [word_id] + if word_id == eos_id: + completed_hypotheses.append(hyp_tgt_words) + completed_hypothesis_scores.append(new_hyp_score) + else: + new_hypotheses.append(hyp_tgt_words) + live_hyp_ids.append(prev_hyp_id) + new_hyp_scores.append(new_hyp_score) + + if len(completed_hypotheses) == beam_size: + break + + live_hyp_ids = torch.LongTensor(live_hyp_ids) + if args.cuda: + live_hyp_ids = live_hyp_ids.cuda() + + hidden = (h_t[live_hyp_ids], cell_t[live_hyp_ids]) + att_tm1 = att_t[live_hyp_ids] + + hyp_scores = Variable(torch.FloatTensor(new_hyp_scores), volatile=True) # new_hyp_scores[live_hyp_ids] + if args.cuda: + hyp_scores = hyp_scores.cuda() + hypotheses = new_hypotheses + + if len(completed_hypotheses) == 0: + completed_hypotheses = 
[hypotheses[0]]
+            completed_hypothesis_scores = [0.0]
+
+        if to_word:
+            for i, hyp in enumerate(completed_hypotheses):
+                completed_hypotheses[i] = [self.vocab.tgt.id2word[w] for w in hyp]
+
+        ranked_hypotheses = sorted(zip(completed_hypotheses, completed_hypothesis_scores), key=lambda x: x[1], reverse=True)
+
+        return [hyp for hyp, score in ranked_hypotheses]
+
+    def sample(self, src_sents, sample_size=None, to_word=False):
+        if not type(src_sents[0]) == list:
+            src_sents = [src_sents]
+        if not sample_size:
+            sample_size = args.sample_size
+
+        src_sents_num = len(src_sents)
+        batch_size = src_sents_num * sample_size
+
+        src_sents_var = to_input_variable(src_sents, self.vocab.src, cuda=args.cuda, is_test=True)
+        src_encoding, sv, (dec_init_state, dec_init_cell) = self.encode(src_sents_var, [len(s) for s in src_sents])
+
+        dec_init_state = dec_init_state.repeat(sample_size, 1)
+        dec_init_cell = dec_init_cell.repeat(sample_size, 1)
+        hidden = (dec_init_state, dec_init_cell)
+
+        src_encoding = src_encoding.repeat(1, sample_size, 1)
+        src_encoding_att_linear = tensor_transform(self.att_src_linear, src_encoding)
+        src_encoding = src_encoding.permute(1, 0, 2)
+        src_encoding_att_linear = src_encoding_att_linear.permute(1, 0, 2)
+
+        new_tensor = dec_init_state.data.new
+        att_tm1 = Variable(new_tensor(batch_size, self.args.hidden_size).zero_(), volatile=True)
+        y_0 = Variable(torch.LongTensor([self.vocab.tgt['<s>'] for _ in xrange(batch_size)]), volatile=True)
+
+        eos = self.vocab.tgt['</s>']
+        # eos_batch = torch.LongTensor([eos] * batch_size)
+        sample_ends = torch.ByteTensor([0] * batch_size)
+        all_ones = torch.ByteTensor([1] * batch_size)
+        if args.cuda:
+            y_0 = y_0.cuda()
+            sample_ends = sample_ends.cuda()
+            all_ones = all_ones.cuda()
+
+        samples = [y_0]
+
+        t = 0
+        while t < args.decode_max_time_step:
+            t += 1
+
+            # (sample_size)
+            y_tm1 = samples[-1]
+
+            y_tm1_embed = self.tgt_embed(y_tm1)
+
+            x = torch.cat([y_tm1_embed, att_tm1], 1)
+
+            # h_t: (batch_size, hidden_size)
+            h_t, cell_t = self.decoder_lstm(x, hidden)
+            h_t = self.dropout(h_t)
+
+            ctx_t, alpha_t = self.dot_prod_attention(h_t, src_encoding, src_encoding_att_linear)
+
+            att_t = F.tanh(self.att_vec_linear(torch.cat([h_t, ctx_t], 1))) # E.q. (5)
+            att_t = self.dropout(att_t)
+
+            score_t = self.readout(att_t) # E.q.
(6) + p_t = F.softmax(score_t) + + if args.sample_method == 'random': + y_t = torch.multinomial(p_t, num_samples=1).squeeze(1) + elif args.sample_method == 'greedy': + _, y_t = torch.topk(p_t, k=1, dim=1) + y_t = y_t.squeeze(1) + + samples.append(y_t) + + sample_ends |= torch.eq(y_t, eos).byte().data + if torch.equal(sample_ends, all_ones): + break + + # if torch.equal(y_t.data, eos_batch): + # break + + att_tm1 = att_t + hidden = h_t, cell_t + + # post-processing + completed_samples = [list([list() for _ in xrange(sample_size)]) for _ in xrange(src_sents_num)] + for y_t in samples: + for i, sampled_word in enumerate(y_t.cpu().data): + src_sent_id = i % src_sents_num + sample_id = i / src_sents_num + if len(completed_samples[src_sent_id][sample_id]) == 0 or completed_samples[src_sent_id][sample_id][-1] != eos: + completed_samples[src_sent_id][sample_id].append(sampled_word) + + if to_word: + for i, src_sent_samples in enumerate(completed_samples): + completed_samples[i] = word2id(src_sent_samples, self.vocab.tgt.id2word) + + return completed_samples + + def attention(self, h_t, src_encoding, src_linear_for_att): + # (1, batch_size, attention_size) + (src_sent_len, batch_size, attention_size) => + # (src_sent_len, batch_size, attention_size) + att_hidden = F.tanh(self.att_h_linear(h_t).unsqueeze(0).expand_as(src_linear_for_att) + src_linear_for_att) + + # (batch_size, src_sent_len) + att_weights = F.softmax(tensor_transform(self.att_vec_linear, att_hidden).squeeze(2).permute(1, 0)) + + # (batch_size, hidden_size * 2) + ctx_vec = torch.bmm(src_encoding.permute(1, 2, 0), att_weights.unsqueeze(2)).squeeze(2) + + return ctx_vec, att_weights + + def dot_prod_attention(self, h_t, src_encoding, src_encoding_att_linear, mask=None): + """ + :param h_t: (batch_size, hidden_size) + :param src_encoding: (batch_size, src_sent_len, hidden_size * 2) + :param src_encoding_att_linear: (batch_size, src_sent_len, hidden_size) + :param mask: (batch_size, src_sent_len) + """ + # (batch_size, src_sent_len) + att_weight = torch.bmm(src_encoding_att_linear, h_t.unsqueeze(2)).squeeze(2) + if mask: + att_weight.data.masked_fill_(mask, -float('inf')) + att_weight = F.softmax(att_weight) + + att_view = (att_weight.size(0), 1, att_weight.size(1)) + # (batch_size, hidden_size) + ctx_vec = torch.bmm(att_weight.view(*att_view), src_encoding).squeeze(1) + + return ctx_vec, att_weight + + def save(self, path): + print('save parameters to [%s]' % path, file=sys.stderr) + params = { + 'args': self.args, + 'vocab': self.vocab, + 'state_dict': self.state_dict() + } + torch.save(params, path) + + +def to_input_variable(sents, vocab, cuda=False, is_test=False): + """ + return a tensor of shape (src_sent_len, batch_size) + """ + + word_ids = word2id(sents, vocab) + sents_t, masks = input_transpose(word_ids, vocab['']) + + sents_var = Variable(torch.LongTensor(sents_t), volatile=is_test, requires_grad=False) + if cuda: + sents_var = sents_var.cuda() + + return sents_var + + +def evaluate_loss(model, data, crit): + model.eval() + cum_loss = 0. + cum_tgt_words = 0. 
+    for src_sents, tgt_sents in data_iter(data, batch_size=args.batch_size, shuffle=False):
+        pred_tgt_word_num = sum(len(s[1:]) for s in tgt_sents)  # omitting leading `<s>`
+        src_sents_len = [len(s) for s in src_sents]
+
+        src_sents_var = to_input_variable(src_sents, model.vocab.src, cuda=args.cuda, is_test=True)
+        tgt_sents_var = to_input_variable(tgt_sents, model.vocab.tgt, cuda=args.cuda, is_test=True)
+
+        # (tgt_sent_len, batch_size, tgt_vocab_size)
+        scores = model('summarization', src_sents_var, src_sents_len, tgt_sents_var[:-1], src_sents_len)
+        loss = crit(scores.view(-1, scores.size(2)), tgt_sents_var[1:].view(-1))
+
+        cum_loss += loss.data[0]
+        cum_tgt_words += pred_tgt_word_num
+
+    loss = cum_loss / cum_tgt_words
+    return loss
+
+def init_training(args):
+    vocab = torch.load(args.vocab)
+
+    model = NMT(args, vocab)
+    model.train()
+
+    if args.uniform_init:
+        print('uniformly initialize parameters [-%f, +%f]' % (args.uniform_init, args.uniform_init), file=sys.stderr)
+        for p in model.parameters():
+            p.data.uniform_(-args.uniform_init, args.uniform_init)
+
+    vocab_mask = torch.ones(len(vocab.tgt))
+    vocab_mask[vocab.tgt['<pad>']] = 0
+    nll_loss = nn.NLLLoss(weight=vocab_mask, size_average=False)
+    cross_entropy_loss = nn.CrossEntropyLoss(weight=vocab_mask, size_average=False)
+
+    weight = torch.FloatTensor(3).fill_(1)
+    loss_ent = nn.CrossEntropyLoss(weight=weight, size_average=False)
+
+    if args.cuda:
+        model = model.cuda()
+        nll_loss = nll_loss.cuda()
+        cross_entropy_loss = cross_entropy_loss.cuda()
+        loss_ent = loss_ent.cuda()
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+
+    return vocab, model, optimizer, nll_loss, cross_entropy_loss, loss_ent
+
+def cont_training_halve_raml(args, lr_now):
+    load_model_dir = args.save_to + '.bin'
+    print('load model from [%s]' % load_model_dir, file=sys.stderr)
+    params = torch.load(load_model_dir, map_location=lambda storage, loc: storage)
+    vocab = params['vocab']
+    state_dict = params['state_dict']
+
+    model = NMT(args, vocab)
+    model.train()
+    model.load_state_dict(state_dict)
+
+    vocab_mask = torch.ones(len(vocab.tgt))
+    vocab_mask[vocab.tgt['<pad>']] = 0
+    nll_loss = nn.NLLLoss(weight=vocab_mask, size_average=False)
+    cross_entropy_loss = nn.CrossEntropyLoss(weight=vocab_mask, size_average=False)
+
+    if args.cuda:
+        model = model.cuda()
+        nll_loss = nll_loss.cuda()
+        cross_entropy_loss = cross_entropy_loss.cuda()
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr_now)
+
+    return vocab, model, optimizer, nll_loss, cross_entropy_loss
+
+def cont_training_halve_ent(args, lr_now):
+    load_model_dir = args.save_to + '.bin'
+    print('load model from [%s]' % load_model_dir, file=sys.stderr)
+    params = torch.load(load_model_dir, map_location=lambda storage, loc: storage)
+    vocab = params['vocab']
+    state_dict = params['state_dict']
+
+    model = NMT(args, vocab)
+    model.train()
+    model.load_state_dict(state_dict)
+
+    vocab_mask = torch.ones(len(vocab.tgt))
+    vocab_mask[vocab.tgt['<pad>']] = 0
+    nll_loss = nn.NLLLoss(weight=vocab_mask, size_average=False)
+    cross_entropy_loss = nn.CrossEntropyLoss(weight=vocab_mask, size_average=False)
+
+    weight = torch.FloatTensor(3).fill_(1)
+    loss_ent = nn.CrossEntropyLoss(weight=weight, size_average=False)
+
+    if args.cuda:
+        model = model.cuda()
+        nll_loss = nll_loss.cuda()
+        cross_entropy_loss = cross_entropy_loss.cuda()
+        loss_ent = loss_ent.cuda()
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr_now)
+
+    return vocab, model, optimizer, nll_loss, cross_entropy_loss, loss_ent
+
+
+def
cont_training_raml(args): + print('load model from [%s]' % args.load_model, file=sys.stderr) + params = torch.load(args.load_model, map_location=lambda storage, loc: storage) + vocab = params['vocab'] + state_dict = params['state_dict'] + + model = NMT(args, vocab) + model.train() + model.load_state_dict(state_dict, strict=False) + + vocab_mask = torch.ones(len(vocab.tgt)) + vocab_mask[vocab.tgt['']] = 0 + nll_loss = nn.NLLLoss(weight=vocab_mask, size_average=False) + cross_entropy_loss = nn.CrossEntropyLoss(weight=vocab_mask, size_average=False) + + if args.cuda: + model = model.cuda() + nll_loss = nll_loss.cuda() + cross_entropy_loss = cross_entropy_loss.cuda() + + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + + return vocab, model, optimizer, nll_loss, cross_entropy_loss + +def cont_training(args): + print('load model from [%s]' % args.load_model, file=sys.stderr) + params = torch.load(args.load_model, map_location=lambda storage, loc: storage) + vocab = params['vocab'] + state_dict = params['state_dict'] + + model = NMT(args, vocab) + model.train() + model.load_state_dict(state_dict, strict=False) + + vocab_mask = torch.ones(len(vocab.tgt)) + vocab_mask[vocab.tgt['']] = 0 + nll_loss = nn.NLLLoss(weight=vocab_mask, size_average=False) + cross_entropy_loss = nn.CrossEntropyLoss(weight=vocab_mask, size_average=False) + + weight = torch.FloatTensor(3).fill_(1) + loss_ent = nn.CrossEntropyLoss(weight=weight, size_average=False) + + if args.cuda: + model = model.cuda() + nll_loss = nll_loss.cuda() + cross_entropy_loss = cross_entropy_loss.cuda() + loss_ent = loss_ent.cuda() + + + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + + return vocab, model, optimizer, nll_loss, cross_entropy_loss, loss_ent + + +def train(args): + train_data_src = read_corpus(args.train_src, source='src') + train_data_tgt = read_corpus(args.train_tgt, source='tgt') + train_data_ent_x = read_corpus(args.train_ent_x, source='src') + train_data_ent_y = read_corpus(args.train_ent_y, source='src') + train_data_ent_label = read_corpus(args.train_ent_label, source='src') + train_data_ent_label = map_ent_label(train_data_ent_label) + + dev_data_src = read_corpus(args.dev_src, source='src') + dev_data_tgt = read_corpus(args.dev_tgt, source='tgt') + + train_data = zip(train_data_src, train_data_tgt) + train_data_ent = zip(train_data_ent_x, train_data_ent_y, train_data_ent_label) + + dev_data = zip(dev_data_src, dev_data_tgt) + + if not args.load_model: + vocab, model , optimizer, nll_loss, cross_entropy_loss, loss_ent = init_training(args) + else: + vocab, model, optimizer, nll_loss, cross_entropy_loss, loss_ent = cont_training(args) + + train_iter = patience = drop_patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0 + cum_examples = cum_batches = report_examples = valid_num = best_model_iter = 0 + train_multi_task_iter = 0 + train_ent_iter = 0 + hist_valid_scores = [] + train_time = begin_time = time.time() + print('begin Maximum Likelihood training') + + sum_data_generator = data_iter(train_data, batch_size=args.batch_size) + ent_data_generator = data_iter(train_data_ent, batch_size=args.batch_size) + + while True: + train_multi_task_iter += 1 + for train_iter in range((train_multi_task_iter - 1) * 100, train_multi_task_iter * 100): + try: + src_sents, tgt_sents = next(sum_data_generator) + except: + sum_data_generator = data_iter(train_data, batch_size=args.batch_size) + src_sents, tgt_sents = next(sum_data_generator) + src_sents_var = to_input_variable(src_sents, 
vocab.src, cuda=args.cuda) + tgt_sents_var = to_input_variable(tgt_sents, vocab.tgt, cuda=args.cuda) + + batch_size = len(src_sents) + src_sents_len = [len(s) for s in src_sents] + pred_tgt_word_num = sum(len(s[1:]) for s in tgt_sents) # omitting leading `` + + optimizer.zero_grad() + + # (tgt_sent_len, batch_size, tgt_vocab_size) + scores = model('summarization', src_sents_var, src_sents_len, tgt_sents_var[:-1], src_sents_len) + + word_loss = cross_entropy_loss(scores.view(-1, scores.size(2)), tgt_sents_var[1:].view(-1)) + loss = word_loss / batch_size + word_loss_val = word_loss.data[0] + loss_val = loss.data[0] + + loss.backward() + # clip gradient + grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad) + optimizer.step() + + report_loss += word_loss_val + cum_loss += word_loss_val + report_tgt_words += pred_tgt_word_num + cum_tgt_words += pred_tgt_word_num + report_examples += batch_size + cum_examples += batch_size + cum_batches += batch_size + + if train_iter > 0 and train_iter % args.log_every == 0: + print('iter %d, avg. loss %.2f, avg. ppl %.2f ' \ + 'cum. examples %d, speed %.2f words/sec, time elapsed %.2f sec' % (train_iter, + report_loss / report_examples, + np.exp(report_loss / report_tgt_words), + cum_examples, + report_tgt_words / (time.time() - train_time), + time.time() - begin_time), file=sys.stderr) + + train_time = time.time() + report_loss = report_tgt_words = report_examples = 0. + + # perform validation + if train_iter > 0 and train_iter % args.valid_niter == 0: + print('iter %d, cum. loss %.2f, cum. ppl %.2f cum. examples %d' % (train_iter, + cum_loss / cum_batches, + np.exp(cum_loss / cum_tgt_words), + cum_examples), file=sys.stderr) + + cum_loss = cum_batches = cum_tgt_words = 0. + valid_num += 1 + + print('begin validation ...', file=sys.stderr) + model.eval() + + # compute dev. ppl and bleu + + dev_loss = evaluate_loss(model, dev_data, cross_entropy_loss) + dev_ppl = np.exp(dev_loss) + + if args.valid_metric in ['rouge2_f', 'bleu', 'word_acc', 'sent_acc']: + dev_hyps = decode(model, dev_data) + dev_hyps = [hyps[0] for hyps in dev_hyps] + if args.valid_metric == 'bleu': + valid_metric = get_bleu([tgt for src, tgt in dev_data], dev_hyps) + elif args.valid_metric == 'rouge2_f': + valid_metric = get_rouge2f([tgt for src, tgt in dev_data], dev_hyps) + else: + valid_metric = get_acc([tgt for src, tgt in dev_data], dev_hyps, acc_type=args.valid_metric) + print('validation: iter %d, dev. ppl %f, dev. %s %f' % (train_iter, dev_ppl, args.valid_metric, valid_metric), + file=sys.stderr) + else: + valid_metric = -dev_ppl + print('validation: iter %d, dev. 
ppl %f' % (train_iter, dev_ppl), + file=sys.stderr) + + model.train() + + is_better = len(hist_valid_scores) == 0 or valid_metric > max(hist_valid_scores) + is_better_than_last = len(hist_valid_scores) == 0 or valid_metric > hist_valid_scores[-1] + hist_valid_scores.append(valid_metric) + + if valid_num > args.save_model_after: + model_file = args.save_to + '.iter%d.bin' % train_iter + print('save model to [%s]' % model_file, file=sys.stderr) + model.save(model_file) + if is_better: + patience = 0 + drop_patience = 0 + best_model_iter = train_iter + + if valid_num > args.save_model_after: + print('save currently the best model ..', file=sys.stderr) + model_file_abs_path = os.path.abspath(model_file) + symlin_file_abs_path = os.path.abspath(args.save_to + '.bin') + os.system('ln -sf %s %s' % (model_file_abs_path, symlin_file_abs_path)) + else: + drop_patience += 1 + print('hit drop patience %d' % drop_patience, file=sys.stderr) + if drop_patience == args.drop_patience and args.lr_decay: + drop_patience = 0 + lr = optimizer.param_groups[0]['lr'] * args.lr_decay + vocab, model, optimizer, nll_loss, cross_entropy_loss, loss_ent = cont_training_halve_ent(args, lr) + print('decay learning rate to %.12f' % lr, file=sys.stderr) + + patience += 1 + print('hit patience %d' % patience, file=sys.stderr) + if patience == args.patience: + print('early stop!', file=sys.stderr) + print('the best model is from iteration [%d]' % best_model_iter, file=sys.stderr) + exit(0) + for train_ent_iter in range((train_multi_task_iter - 1) * args.switch, train_multi_task_iter * args.switch): + try: + ent_xs, ent_ys, ent_ls = next(ent_data_generator) + except: + ent_data_generator = data_iter(train_data_ent, batch_size=args.batch_size) + ent_xs, ent_ys, ent_ls = next(ent_data_generator) + + ent_xs_var = to_input_variable(ent_xs, vocab.src, cuda=args.cuda) + ent_ys_var = to_input_variable(ent_ys, vocab.src, cuda=args.cuda) + ent_ls_var = Variable(torch.LongTensor(ent_ls), volatile=False, requires_grad=False) + ent_ls_var = torch.squeeze(ent_ls_var) + if args.cuda: + ent_ls_var = ent_ls_var.cuda() + + batch_size = len(ent_xs) + ent_xs_len = [len(s) for s in ent_xs] + ent_ys_len = [len(s) for s in ent_ys] + + optimizer.zero_grad() + + # (tgt_sent_len, batch_size, tgt_vocab_size) -> (batch_size) + scores = model('entailment', ent_xs_var, ent_xs_len, ent_ys_var, ent_ys_len) + loss = loss_ent(scores, ent_ls_var) / batch_size + + loss.backward() + # clip gradient + grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad) + optimizer.step() + if train_ent_iter > 0 and ((train_ent_iter + 1 )* 100 / args.switch) % args.log_every == 0: + print('iter %d entailment classifier training complete .' 
% (train_ent_iter + 1)) + +def read_raml_train_data(data_file, temp): + train_data = dict() + num_pattern = re.compile('^(\d+) samples$') + with open(data_file) as f: + while True: + line = f.readline() + if line is None or line == '': + break + + assert line.startswith('***') + + src_sent = f.readline()[len('source: '):].strip() + tgt_num = int(num_pattern.match(f.readline().strip()).group(1)) + tgt_samples = [] + tgt_scores = [] + for i in xrange(tgt_num): + d = f.readline().strip().split(' ||| ') + if len(d) < 2: + continue + + tgt_sent = d[0].strip() + bleu_score = float(d[1]) + tgt_samples.append(tgt_sent) + tgt_scores.append(bleu_score / temp) + + tgt_scores = np.exp(tgt_scores) + tgt_scores = tgt_scores / np.sum(tgt_scores) + + tgt_entry = zip(tgt_samples, tgt_scores) + train_data[src_sent] = tgt_entry + + line = f.readline() + + return train_data + +def train_raml(args): + tau = args.temp + + train_data_src = read_corpus(args.train_src, source='src') + train_data_tgt = read_corpus(args.train_tgt, source='tgt') + train_data = zip(train_data_src, train_data_tgt) + + dev_data_src = read_corpus(args.dev_src, source='src') + dev_data_tgt = read_corpus(args.dev_tgt, source='tgt') + dev_data = zip(dev_data_src, dev_data_tgt) + + + assert args.load_model, 'You have to specify a pre-trained model' + vocab, model, optimizer, nll_loss, cross_entropy_loss = cont_training_raml(args) + + if args.raml_sample_mode == 'pre_sample': + # dict of (src, [tgt: (sent, prob)]) + print('read in raml training data...', file=sys.stderr, end='') + begin_time = time.time() + raml_samples = read_raml_train_data(args.raml_sample_file, temp=tau) + print('done[%d s].' % (time.time() - begin_time)) + elif args.raml_sample_mode.startswith('hamming_distance'): + print('sample from hamming distance payoff distribution') + payoff_prob, Z_qs = generate_hamming_distance_payoff_distribution(max(len(sent) for sent in train_data_tgt), + vocab_size=len(vocab.tgt) - 3, + tau=tau) + + train_iter = patience = drop_patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0 + report_weighted_loss = cum_weighted_loss = 0 + cum_examples = cum_batches = report_examples = epoch = valid_num = best_model_iter = 0 + hist_valid_scores = [] + train_time = begin_time = time.time() + print('begin RAML training') + sm_func = None + if args.smooth_bleu: + sm_func = SmoothingFunction().method3 + + if args.reward_type == 'entailment': + word_dic = {} + for line in open("/path/to/nli/word/dict").readlines(): + v, k = line.strip().split() + word_dic[v] = int(k) + input_encoder = encoder(137665, 300, 300, 0.01).cuda() + inter_atten = atten(300, 3, 0.01).cuda() + input_encoder.load_state_dict(torch.load('/path/to/input-encoder.pt', map_location=lambda storage, loc: storage)) + inter_atten.load_state_dict(torch.load('/path/to/inter-atten.pt', map_location=lambda storage, loc: storage)) + input_encoder.eval() + inter_atten.eval() + + while True: + epoch += 1 + for src_sents, tgt_sents in data_iter(train_data, batch_size=args.batch_size): + train_iter += 1 + + raml_src_sents = [] + raml_tgt_sents = [] + raml_tgt_weights = [] + + if args.raml_sample_mode == 'pre_sample': + for src_sent in src_sents: + tgt_samples_all = raml_samples[' '.join(src_sent)] + + if args.sample_size >= len(tgt_samples_all): + tgt_samples = tgt_samples_all + else: + tgt_samples_id = np.random.choice(range(1, len(tgt_samples_all)), size=args.sample_size - 1, replace=False) + tgt_samples = [tgt_samples_all[0]] + [tgt_samples_all[i] for i in tgt_samples_id] # make 
sure the ground truth y* is in the samples + + raml_src_sents.extend([src_sent] * len(tgt_samples)) + raml_tgt_sents.extend([[''] + sent.split(' ') + [''] for sent, weight in tgt_samples]) + raml_tgt_weights.extend([weight for sent, weight in tgt_samples]) + elif args.raml_sample_mode in ['hamming_distance', 'hamming_distance_impt_sample']: + for src_sent, tgt_sent in zip(src_sents, tgt_sents): + tgt_samples = [] # make sure the ground truth y* is in the samples + tgt_sent_len = len(tgt_sent) - 3 # remove and and ending period . + tgt_ref_tokens = tgt_sent[1:-1] + bleu_scores = [] + # sample an edit distances + e_samples = np.random.choice(range(tgt_sent_len + 1), p=payoff_prob[tgt_sent_len], size=args.sample_size, replace=True) + + # make sure the ground truth y* is in the samples + if args.raml_bias_groundtruth and (not 0 in e_samples): + e_samples[0] = 0 + + for i, e in enumerate(e_samples): + if e > 0: + # sample a new tgt_sent $y$ + old_word_pos = np.random.choice(range(1, tgt_sent_len + 1), size=e, replace=False) + new_words = [vocab.tgt.id2word[wid] for wid in np.random.randint(3, len(vocab.tgt), size=e)] + new_tgt_sent = list(tgt_sent) + for pos, word in zip(old_word_pos, new_words): + new_tgt_sent[pos] = word + else: + new_tgt_sent = list(tgt_sent) + + # if enable importance sampling, compute bleu score + if args.raml_sample_mode == 'hamming_distance_impt_sample': + if e > 0: + # remove and + if args.reward_type == 'bleu': + reward_score = sentence_bleu([tgt_ref_tokens], new_tgt_sent[1:-1], smoothing_function=sm_func) + elif args.reward_type == 'rouge2_f': + #print(tgt_ref_tokens) + #print(new_tgt_sent[1:-1]) + #print(aa) + reward_score = rouge_2(tgt_ref_tokens, new_tgt_sent[1:-1]) + elif args.reward_type == 'entailment': + reward_score_tgt = sentence_entailment(src_sent[1:-1], tgt_sent[1:-1], + word_dic, input_encoder, inter_atten) + reward_score_sample = sentence_entailment(src_sent[1:-1], new_tgt_sent[1:-1], + word_dic, input_encoder, inter_atten) + reward_score = min(reward_score_tgt, reward_score_sample) + bleu_scores.append(reward_score) + else: + bleu_scores.append(1.) + + # print('y: %s' % ' '.join(new_tgt_sent)) + tgt_samples.append(new_tgt_sent) + + # if enable importance sampling, compute importance weight + if args.raml_sample_mode == 'hamming_distance_impt_sample': + tgt_sample_weights = [math.exp(bleu_score / tau) / math.exp(-e / tau) for e, bleu_score in zip(e_samples, bleu_scores)] + normalizer = sum(tgt_sample_weights) + tgt_sample_weights = [w / normalizer for w in tgt_sample_weights] + else: + tgt_sample_weights = [1.] 
* args.sample_size + + if args.debug: + print('*' * 30) + print('Target: %s' % ' '.join(tgt_sent)) + for tgt_sample, e, bleu_score, weight in zip(tgt_samples, e_samples, bleu_scores, + tgt_sample_weights): + print('Sample: %s ||| e: %d ||| bleu: %f ||| weight: %f' % ( + ' '.join(tgt_sample), e, bleu_score, weight)) + print() + + raml_src_sents.extend([src_sent] * len(tgt_samples)) + raml_tgt_sents.extend(tgt_samples) + raml_tgt_weights.extend(tgt_sample_weights) + + src_sents_var = to_input_variable(raml_src_sents, vocab.src, cuda=args.cuda) + tgt_sents_var = to_input_variable(raml_tgt_sents, vocab.tgt, cuda=args.cuda) + weights_var = Variable(torch.FloatTensor(raml_tgt_weights), requires_grad=False) + if args.cuda: + weights_var = weights_var.cuda() + + batch_size = len(raml_src_sents) # batch_size = args.batch_size * args.sample_size + src_sents_len = [len(s) for s in raml_src_sents] + pred_tgt_word_num = sum(len(s[1:]) for s in raml_tgt_sents) # omitting leading `` + optimizer.zero_grad() + + # (tgt_sent_len, batch_size, tgt_vocab_size) + scores = model('summarization', src_sents_var, src_sents_len, tgt_sents_var[:-1], src_sents_len) + # (tgt_sent_len * batch_size, tgt_vocab_size) + log_scores = F.log_softmax(scores.view(-1, scores.size(2))) + # remove leading in tgt sent, which is not used as the target + flattened_tgt_sents = tgt_sents_var[1:].view(-1) + + # batch_size * tgt_sent_len + tgt_log_scores = torch.gather(log_scores, 1, flattened_tgt_sents.unsqueeze(1)).squeeze(1) + unweighted_loss = -tgt_log_scores * (1. - torch.eq(flattened_tgt_sents, 0).float()) + weighted_loss = unweighted_loss * weights_var.repeat(scores.size(0)) + weighted_loss = weighted_loss.sum() + weighted_loss_val = weighted_loss.data[0] + nll_loss_val = unweighted_loss.sum().data[0] + + loss = weighted_loss / batch_size + # nll_loss_val = nll_loss(log_scores, flattened_tgt_sents).data[0] + loss.backward() + # clip gradient + grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_grad) + optimizer.step() + + report_weighted_loss += weighted_loss_val + cum_weighted_loss += weighted_loss_val + report_loss += nll_loss_val + cum_loss += nll_loss_val + report_tgt_words += pred_tgt_word_num + cum_tgt_words += pred_tgt_word_num + report_examples += batch_size + cum_examples += batch_size + cum_batches += batch_size + + if train_iter % args.log_every == 0: + print('epoch %d, iter %d, avg. loss %.2f, ' + 'avg. ppl %.2f cum. examples %d, ' + 'speed %.2f words/sec, time elapsed %.2f sec' % (epoch, train_iter, + report_weighted_loss / report_examples, + np.exp(report_loss / report_tgt_words), + cum_examples, + report_tgt_words / (time.time() - train_time), + time.time() - begin_time), + file=sys.stderr) + + train_time = time.time() + report_loss = report_weighted_loss = report_tgt_words = report_examples = 0. + + # perform validation + if train_iter % args.valid_niter == 0: + print('epoch %d, iter %d, cum. loss %.2f, ' + 'cum. ppl %.2f cum. examples %d' % (epoch, train_iter, + cum_weighted_loss / cum_batches, + np.exp(cum_loss / cum_tgt_words), + cum_examples), + file=sys.stderr) + + cum_loss = cum_weighted_loss = cum_batches = cum_tgt_words = 0. + valid_num += 1 + + print('begin validation ...', file=sys.stderr) + model.eval() + + # compute dev. 
ppl and bleu + + dev_loss = evaluate_loss(model, dev_data, cross_entropy_loss) + dev_ppl = np.exp(dev_loss) + + if args.valid_metric in ['rouge2_f', 'bleu', 'word_acc', 'sent_acc']: + dev_hyps = decode(model, dev_data) + dev_hyps = [hyps[0] for hyps in dev_hyps] + if args.valid_metric == 'bleu': + valid_metric = get_bleu([tgt for src, tgt in dev_data], dev_hyps) + elif args.valid_metric == 'rouge2_f': + valid_metric = get_rouge2f([tgt for src, tgt in dev_data], dev_hyps) + else: + valid_metric = get_acc([tgt for src, tgt in dev_data], dev_hyps, acc_type=args.valid_metric) + print('validation: iter %d, dev. ppl %f, dev. %s %f' % (train_iter, dev_ppl, args.valid_metric, valid_metric), + file=sys.stderr) + else: + valid_metric = -dev_ppl + print('validation: iter %d, dev. ppl %f' % (train_iter, dev_ppl), + file=sys.stderr) + + + model.train() + is_better = len(hist_valid_scores) == 0 or valid_metric > max(hist_valid_scores) + is_better_than_last = len(hist_valid_scores) == 0 or valid_metric > hist_valid_scores[-1] + hist_valid_scores.append(valid_metric) + + if valid_num > args.save_model_after: + model_file = args.save_to + '.iter%d.bin' % train_iter + print('save model to [%s]' % model_file, file=sys.stderr) + model.save(model_file) + + if is_better: + drop_patience = 0 + patience = 0 + best_model_iter = train_iter + + if valid_num > args.save_model_after: + print('save currently the best model ..', file=sys.stderr) + model_file_abs_path = os.path.abspath(model_file) + symlin_file_abs_path = os.path.abspath(args.save_to + '.bin') + os.system('ln -sf %s %s' % (model_file_abs_path, symlin_file_abs_path)) + else: + drop_patience += 1 + print('hit drop patience %d' % drop_patience, file=sys.stderr) + if drop_patience == args.drop_patience and args.lr_decay: + drop_patience = 0 + lr = optimizer.param_groups[0]['lr'] * args.lr_decay + vocab, model, optimizer, nll_loss, cross_entropy_loss = cont_training_halve_raml(args, lr) + print('decay learning rate to %.12f' % lr, file=sys.stderr) + optimizer.param_groups[0]['lr'] = lr + + patience += 1 + print('hit patience %d' % patience, file=sys.stderr) + if patience == args.patience: + print('early stop!', file=sys.stderr) + print('the best model is from iteration [%d]' % best_model_iter, file=sys.stderr) + exit(0) + + +def get_bleu(references, hypotheses): + # compute BLEU + bleu_score = corpus_bleu([[ref[1:-1]] for ref in references], + [hyp[1:-1] for hyp in hypotheses]) + + return bleu_score + +def get_rouge2f(references, hypotheses): + # compute ROUGE-2 F1-SCORE + rouge2f_score = rouge_2_corpus([ref[1:-1] for ref in references], + [hyp[1:-1] for hyp in hypotheses]) + return rouge2f_score + +def get_acc(references, hypotheses, acc_type='word'): + assert acc_type == 'word_acc' or acc_type == 'sent_acc' + cum_acc = 0. + + for ref, hyp in zip(references, hypotheses): + ref = ref[1:-1] + hyp = hyp[1:-1] + if acc_type == 'word_acc': + acc = len([1 for ref_w, hyp_w in zip(ref, hyp) if ref_w == hyp_w]) / float(len(hyp) + 1e-6) + else: + acc = 1. if all(ref_w == hyp_w for ref_w, hyp_w in zip(ref, hyp)) else 0. + cum_acc += acc + + acc = cum_acc / len(hypotheses) + return acc + + +def decode(model, data, verbose=False): + """ + decode the dataset and compute sentence level acc. and BLEU. 
+ """ + hypotheses = [] + begin_time = time.time() + + if type(data[0]) is tuple: + for src_sent, tgt_sent in data: + hyps = model.translate(src_sent) + hypotheses.append(hyps) + + if verbose: + print('*' * 50) + print('Source: ', ' '.join(src_sent)) + print('Target: ', ' '.join(tgt_sent)) + print('Top Hypothesis: ', ' '.join(hyps[0])) + else: + for src_sent in data: + hyps = model.translate(src_sent) + hypotheses.append(hyps) + + if verbose: + print('*' * 50) + print('Source: ', ' '.join(src_sent)) + print('Top Hypothesis: ', ' '.join(hyps[0])) + + elapsed = time.time() - begin_time + + print('decoded %d examples, took %d s' % (len(data), elapsed), file=sys.stderr) + + return hypotheses + + +def compute_lm_prob(args): + """ + given source-target sentence pairs, compute ppl and log-likelihood + """ + test_data_src = read_corpus(args.test_src, source='src') + test_data_tgt = read_corpus(args.test_tgt, source='tgt') + test_data = zip(test_data_src, test_data_tgt) + + if args.load_model: + print('load model from [%s]' % args.load_model, file=sys.stderr) + params = torch.load(args.load_model, map_location=lambda storage, loc: storage) + vocab = params['vocab'] + saved_args = params['args'] + state_dict = params['state_dict'] + + model = NMT(saved_args, vocab) + model.load_state_dict(state_dict) + else: + vocab = torch.load(args.vocab) + model = NMT(args, vocab) + + model.eval() + + if args.cuda: + model = model.cuda() + + f = open(args.save_to_file, 'w') + for src_sent, tgt_sent in test_data: + src_sents = [src_sent] + tgt_sents = [tgt_sent] + + batch_size = len(src_sents) + src_sents_len = [len(s) for s in src_sents] + pred_tgt_word_nums = [len(s[1:]) for s in tgt_sents] # omitting leading `` + + # (sent_len, batch_size) + src_sents_var = to_input_variable(src_sents, model.vocab.src, cuda=args.cuda, is_test=True) + tgt_sents_var = to_input_variable(tgt_sents, model.vocab.tgt, cuda=args.cuda, is_test=True) + + # (tgt_sent_len, batch_size, tgt_vocab_size) + scores = model(src_sents_var, src_sents_len, tgt_sents_var[:-1]) + # (tgt_sent_len * batch_size, tgt_vocab_size) + log_scores = F.log_softmax(scores.view(-1, scores.size(2))) + # remove leading in tgt sent, which is not used as the target + # (batch_size * tgt_sent_len) + flattened_tgt_sents = tgt_sents_var[1:].view(-1) + # (batch_size * tgt_sent_len) + tgt_log_scores = torch.gather(log_scores, 1, flattened_tgt_sents.unsqueeze(1)).squeeze(1) + # 0-index is the symbol + tgt_log_scores = tgt_log_scores * (1. 
- torch.eq(flattened_tgt_sents, 0).float()) + # (tgt_sent_len, batch_size) + tgt_log_scores = tgt_log_scores.view(-1, batch_size) # .permute(1, 0) + # (batch_size) + tgt_sent_scores = tgt_log_scores.sum(dim=0).squeeze() + tgt_sent_word_scores = [tgt_sent_scores[i].data[0] / pred_tgt_word_nums[i] for i in xrange(batch_size)] + + for src_sent, tgt_sent, score in zip(src_sents, tgt_sents, tgt_sent_word_scores): + f.write('%s ||| %s ||| %f\n' % (' '.join(src_sent), ' '.join(tgt_sent), score)) + + f.close() + + +def test(args): + test_data_src = read_corpus(args.test_src, source='src') + test_data_tgt = read_corpus(args.test_tgt, source='tgt') + test_data = zip(test_data_src, test_data_tgt) + + if args.load_model: + print('load model from [%s]' % args.load_model, file=sys.stderr) + params = torch.load(args.load_model, map_location=lambda storage, loc: storage) + vocab = params['vocab'] + saved_args = params['args'] + state_dict = params['state_dict'] + + model = NMT(saved_args, vocab) + model.load_state_dict(state_dict) + else: + vocab = torch.load(args.vocab) + model = NMT(args, vocab) + + model.eval() + + if args.cuda: + model = model.cuda() + + hypotheses = decode(model, test_data) + top_hypotheses = [hyps[0] for hyps in hypotheses] + + bleu_score = get_bleu([tgt for src, tgt in test_data], top_hypotheses) + word_acc = get_acc([tgt for src, tgt in test_data], top_hypotheses, 'word_acc') + sent_acc = get_acc([tgt for src, tgt in test_data], top_hypotheses, 'sent_acc') + print('Corpus Level BLEU: %f, word level acc: %f, sentence level acc: %f' % (bleu_score, word_acc, sent_acc), + file=sys.stderr) + + if args.save_to_file: + print('save decoding results to %s' % args.save_to_file, file=sys.stderr) + with open(args.save_to_file, 'w') as f: + for hyps in hypotheses: + f.write(' '.join(hyps[0][1:-1]) + '\n') + + if args.save_nbest: + nbest_file = args.save_to_file + '.nbest' + print('save nbest decoding results to %s' % nbest_file, file=sys.stderr) + with open(nbest_file, 'w') as f: + for src_sent, tgt_sent, hyps in zip(test_data_src, test_data_tgt, hypotheses): + print('Source: %s' % ' '.join(src_sent), file=f) + print('Target: %s' % ' '.join(tgt_sent), file=f) + print('Hypotheses:', file=f) + for i, hyp in enumerate(hyps, 1): + print('[%d] %s' % (i, ' '.join(hyp)), file=f) + print('*' * 30, file=f) + + +def interactive(args): + assert args.load_model, 'You have to specify a pre-trained model' + print('load model from [%s]' % args.load_model, file=sys.stderr) + params = torch.load(args.load_model, map_location=lambda storage, loc: storage) + vocab = params['vocab'] + saved_args = params['args'] + state_dict = params['state_dict'] + + model = NMT(saved_args, vocab) + model.load_state_dict(state_dict) + + model.eval() + + if args.cuda: + model = model.cuda() + + while True: + src_sent = raw_input('Source Sentence:') + src_sent = src_sent.strip().split(' ') + hyps = model.translate(src_sent) + for i, hyp in enumerate(hyps, 1): + print('Hypothesis #%d: %s' % (i, ' '.join(hyp))) + + +def sample(args): + train_data_src = read_corpus(args.train_src, source='src') + train_data_tgt = read_corpus(args.train_tgt, source='tgt') + train_data = zip(train_data_src, train_data_tgt) + + if args.load_model: + print('load model from [%s]' % args.load_model, file=sys.stderr) + params = torch.load(args.load_model, map_location=lambda storage, loc: storage) + vocab = params['vocab'] + opt = params['args'] + state_dict = params['state_dict'] + + model = NMT(opt, vocab) + model.load_state_dict(state_dict) + else: + 
vocab = torch.load(args.vocab) + model = NMT(args, vocab) + + model.eval() + + if args.cuda: + model = model.cuda() + + print('begin sampling') + + check_every = 10 + train_iter = cum_samples = 0 + train_time = time.time() + for src_sents, tgt_sents in data_iter(train_data, batch_size=args.batch_size): + train_iter += 1 + samples = model.sample(src_sents, sample_size=args.sample_size, to_word=True) + cum_samples += sum(len(sample) for sample in samples) + + if train_iter > 0 and train_iter % check_every == 0: + elapsed = time.time() - train_time + print('sampling speed: %d/s' % (cum_samples / elapsed), file=sys.stderr) + cum_samples = 0 + train_time = time.time() + + for i, tgt_sent in enumerate(tgt_sents): + print('*' * 80) + print('target:' + ' '.join(tgt_sent)) + tgt_samples = samples[i] + print('samples:') + for sid, sample in enumerate(tgt_samples, 1): + print('[%d] %s' % (sid, ' '.join(sample[1:-1]))) + print('*' * 80) + + +if __name__ == '__main__': + args = init_config() + print(args, file=sys.stderr) + + if args.mode == 'train': + train(args) + elif args.mode == 'raml_train': + train_raml(args) + elif args.mode == 'sample': + sample(args) + elif args.mode == 'test': + test(args) + elif args.mode == 'prob': + compute_lm_prob(args) + elif args.mode == 'interactive': + interactive(args) + else: + raise RuntimeError('unknown mode') diff --git a/process_samples.py b/process_samples.py new file mode 100644 index 0000000..2543a5b --- /dev/null +++ b/process_samples.py @@ -0,0 +1,307 @@ +from __future__ import print_function +from nltk.translate.bleu_score import sentence_bleu +from nltk.translate.bleu_score import SmoothingFunction +import sys +import re +import argparse +import torch +from util import read_corpus +import numpy as np +from scipy.misc import comb +from vocab import Vocab, VocabEntry +import math + + +def is_valid_sample(sent): + tokens = sent.split(' ') + return len(tokens) >= 1 and len(tokens) < 50 + + +def sample_from_model(args): + para_data = args.parallel_data + sample_file = args.sample_file + output = args.output + + tgt_sent_pattern = re.compile('^\[(\d+)\] (.*?)$') + para_data = [l.strip().split(' ||| ') for l in open(para_data)] + + f_out = open(output, 'w') + f = open(sample_file) + f.readline() + for src_sent, tgt_sent in para_data: + line = f.readline().strip() + assert line.startswith('****') + line = f.readline().strip() + print(line) + assert line.startswith('target:') + + tgt_sent2 = line[len('target:'):] + assert tgt_sent == tgt_sent2 + + line = f.readline().strip() # samples + + tgt_sent = ' '.join(tgt_sent.split(' ')[1:-1]) + tgt_samples = set() + for i in xrange(1, 101): + line = f.readline().rstrip('\n') + m = tgt_sent_pattern.match(line) + + assert m, line + assert int(m.group(1)) == i + + sampled_tgt_sent = m.group(2).strip() + + if is_valid_sample(sampled_tgt_sent): + tgt_samples.add(sampled_tgt_sent) + + line = f.readline().strip() + assert line.startswith('****') + + tgt_samples.add(tgt_sent) + tgt_samples = list(tgt_samples) + + assert len(tgt_samples) > 0 + + tgt_ref_tokens = tgt_sent.split(' ') + bleu_scores = [] + for tgt_sample in tgt_samples: + bleu_score = sentence_bleu([tgt_ref_tokens], tgt_sample.split(' ')) + bleu_scores.append(bleu_score) + + tgt_ranks = sorted(range(len(tgt_samples)), key=lambda i: bleu_scores[i], reverse=True) + + print('%d samples' % len(tgt_samples)) + + print('*' * 50, file=f_out) + print('source: ' + src_sent, file=f_out) + print('%d samples' % len(tgt_samples), file=f_out) + for i in tgt_ranks: + print('%s ||| 
%f' % (tgt_samples[i], bleu_scores[i]), file=f_out) + print('*' * 50, file=f_out) + + f_out.close() + + +def get_new_ngram(ngram, n, vocab): + """ + replace ngram `ngram` with a newly sampled ngram of the same length + """ + + new_ngram_wids = [np.random.randint(3, len(vocab)) for i in xrange(n)] + new_ngram = [vocab.id2word[wid] for wid in new_ngram_wids] + + return new_ngram + + +def sample_ngram(args): + src_sents = read_corpus(args.src, 'src') + tgt_sents = read_corpus(args.tgt, 'src') # do not read in and + f_out = open(args.output, 'w') + + vocab = torch.load(args.vocab) + tgt_vocab = vocab.tgt + + smooth_bleu = args.smooth_bleu + sm_func = None + if smooth_bleu: + sm_func = SmoothingFunction().method3 + + for src_sent, tgt_sent in zip(src_sents, tgt_sents): + src_sent = ' '.join(src_sent) + + tgt_len = len(tgt_sent) + tgt_samples = [] + tgt_samples_distort_rates = [] # how many unigrams are replaced + + # generate 100 samples + + # append itself + tgt_samples.append(tgt_sent) + tgt_samples_distort_rates.append(0) + + for sid in xrange(args.sample_size - 1): + n = np.random.randint(1, min(tgt_len, args.max_ngram_size + 1)) # we do not replace the last token: it must be a period! + + idx = np.random.randint(tgt_len - n) + ngram = tgt_sent[idx: idx+n] + new_ngram = get_new_ngram(ngram, n, tgt_vocab) + + sampled_tgt_sent = list(tgt_sent) + sampled_tgt_sent[idx: idx+n] = new_ngram + + # compute the probability of this sample + # prob = 1. / args.max_ngram_size * 1. / (tgt_len - 1 + n) * 1 / (len(tgt_vocab) ** n) + + tgt_samples.append(sampled_tgt_sent) + tgt_samples_distort_rates.append(n) + + # compute bleu scores or edit distances and rank the samples by bleu scores + rewards = [] + for tgt_sample, tgt_sample_distort_rate in zip(tgt_samples, tgt_samples_distort_rates): + if args.reward == 'bleu': + reward = sentence_bleu([tgt_sent], tgt_sample, smoothing_function=sm_func) + else: + reward = -tgt_sample_distort_rate + + rewards.append(reward) + + tgt_ranks = sorted(range(len(tgt_samples)), key=lambda i: rewards[i], reverse=True) + # convert list of tokens into a string + tgt_samples = [' '.join(tgt_sample) for tgt_sample in tgt_samples] + + print('*' * 50, file=f_out) + print('source: ' + src_sent, file=f_out) + print('%d samples' % len(tgt_samples), file=f_out) + for i in tgt_ranks: + print('%s ||| %f' % (tgt_samples[i], rewards[i]), file=f_out) + print('*' * 50, file=f_out) + + f_out.close() + + +def sample_ngram_adapt(args): + src_sents = read_corpus(args.src, 'src') + tgt_sents = read_corpus(args.tgt, 'src') # do not read in and + f_out = open(args.output, 'w') + + vocab = torch.load(args.vocab) + tgt_vocab = vocab.tgt + + max_len = max([len(tgt_sent) for tgt_sent in tgt_sents]) + 1 + + for src_sent, tgt_sent in zip(src_sents, tgt_sents): + src_sent = ' '.join(src_sent) + + tgt_len = len(tgt_sent) + tgt_samples = [] + + # generate 100 samples + + # append itself + tgt_samples.append(tgt_sent) + + for sid in xrange(args.sample_size - 1): + max_n = min(tgt_len - 1, 4) + bias_n = int(max_n * tgt_len / max_len) + 1 + assert 1 <= bias_n <= 4, 'bias_n={}, not in [1,4], max_n={}, tgt_len={}, max_len={}'.format(bias_n, max_n, tgt_len, max_len) + + p = [1.0/(max_n + 5)] * max_n + p[bias_n - 1] = 1 - p[0] * (max_n - 1) + assert abs(sum(p) - 1) < 1e-10, 'sum(p) != 1' + + n = np.random.choice(np.arange(1, int(max_n + 1)), p=p) # we do not replace the last token: it must be a period! 
+ assert n < tgt_len, 'n={}, tgt_len={}'.format(n, tgt_len) + + idx = np.random.randint(tgt_len - n) + ngram = tgt_sent[idx: idx+n] + new_ngram = get_new_ngram(ngram, n, tgt_vocab) + + sampled_tgt_sent = list(tgt_sent) + sampled_tgt_sent[idx: idx+n] = new_ngram + + tgt_samples.append(sampled_tgt_sent) + + # compute bleu scores and rank the samples by bleu scores + bleu_scores = [] + for tgt_sample in tgt_samples: + bleu_score = sentence_bleu([tgt_sent], tgt_sample) + bleu_scores.append(bleu_score) + + tgt_ranks = sorted(range(len(tgt_samples)), key=lambda i: bleu_scores[i], reverse=True) + # convert list of tokens into a string + tgt_samples = [' '.join(tgt_sample) for tgt_sample in tgt_samples] + + print('*' * 50, file=f_out) + print('source: ' + src_sent, file=f_out) + print('%d samples' % len(tgt_samples), file=f_out) + for i in tgt_ranks: + print('%s ||| %f' % (tgt_samples[i], bleu_scores[i]), file=f_out) + print('*' * 50, file=f_out) + + f_out.close() + + +def sample_from_hamming_distance_payoff_distribution(args): + src_sents = read_corpus(args.src, 'src') + tgt_sents = read_corpus(args.tgt, 'src') # do not read in and + f_out = open(args.output, 'w') + + vocab = torch.load(args.vocab) + tgt_vocab = vocab.tgt + + payoff_prob, Z_qs = generate_hamming_distance_payoff_distribution(max(len(sent) for sent in tgt_sents), + vocab_size=len(vocab.tgt), + tau=args.temp) + + for src_sent, tgt_sent in zip(src_sents, tgt_sents): + tgt_samples = [] # make sure the ground truth y* is in the samples + tgt_sent_len = len(tgt_sent) - 3 # remove and and ending period . + tgt_ref_tokens = tgt_sent[1:-1] + bleu_scores = [] + + # sample an edit distances + e_samples = np.random.choice(range(tgt_sent_len + 1), p=payoff_prob[tgt_sent_len], size=args.sample_size, + replace=True) + + for i, e in enumerate(e_samples): + if e > 0: + # sample a new tgt_sent $y$ + old_word_pos = np.random.choice(range(1, tgt_sent_len + 1), size=e, replace=False) + new_words = [vocab.tgt.id2word[wid] for wid in np.random.randint(3, len(vocab.tgt), size=e)] + new_tgt_sent = list(tgt_sent) + for pos, word in zip(old_word_pos, new_words): + new_tgt_sent[pos] = word + + bleu_score = sentence_bleu([tgt_ref_tokens], new_tgt_sent[1:-1]) + bleu_scores.append(bleu_score) + else: + new_tgt_sent = list(tgt_sent) + bleu_scores.append(1.) + + # print('y: %s' % ' '.join(new_tgt_sent)) + tgt_samples.append(new_tgt_sent) + + +def generate_hamming_distance_payoff_distribution(max_sent_len, vocab_size, tau=1.): + """compute the q distribution for Hamming Distance (substitution only) as in the RAML paper""" + probs = dict() + Z_qs = dict() + for sent_len in xrange(1, max_sent_len + 1): + counts = [1.] 
# e = 0, count = 1 + for e in xrange(1, sent_len + 1): + # apply the rescaling trick as in https://gist.github.com/norouzi/8c4d244922fa052fa8ec18d8af52d366 + count = comb(sent_len, e) * math.exp(-e / tau) * ((vocab_size - 1) ** (e - e / tau)) + counts.append(count) + + Z_qs[sent_len] = Z_q = sum(counts) + prob = [count / Z_q for count in counts] + probs[sent_len] = prob + + # print('sent_len=%d, %s' % (sent_len, prob)) + + return probs, Z_qs + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--mode', choices=['sample_from_model', 'sample_ngram_adapt', 'sample_ngram'], required=True) + parser.add_argument('--vocab', type=str) + parser.add_argument('--src', type=str) + parser.add_argument('--tgt', type=str) + parser.add_argument('--parallel_data', type=str) + parser.add_argument('--sample_file', type=str) + parser.add_argument('--output', type=str, required=True) + parser.add_argument('--sample_size', type=int, default=100) + parser.add_argument('--reward', choices=['bleu', 'edit_dist'], default='bleu') + parser.add_argument('--max_ngram_size', type=int, default=4) + parser.add_argument('--temp', type=float, default=0.5) + parser.add_argument('--smooth_bleu', action='store_true', default=False) + + args = parser.parse_args() + + if args.mode == 'sample_ngram': + sample_ngram(args) + elif args.mode == 'sample_from_model': + sample_from_model(args) + elif args.mode == 'sample_ngram_adapt': + sample_ngram_adapt(args) \ No newline at end of file diff --git a/rougescore.py b/rougescore.py new file mode 100644 index 0000000..e60be28 --- /dev/null +++ b/rougescore.py @@ -0,0 +1,148 @@ +from __future__ import division +import collections + +import six + +def _ngrams(words, n): + queue = collections.deque(maxlen=n) + for w in words: + queue.append(w) + if len(queue) == n: + yield tuple(queue) + +def _ngram_counts(words, n): + return collections.Counter(_ngrams(words, n)) + +def _ngram_count(words, n): + return max(len(words) - n + 1, 0) + +def _counter_overlap(counter1, counter2): + result = 0 + for k, v in six.iteritems(counter1): + result += min(v, counter2[k]) + return result + +def _safe_divide(numerator, denominator): + if denominator > 0: + return numerator / denominator + else: + return 0 + +def _safe_f1(matches, recall_total, precision_total, alpha=0.5): + recall_score = _safe_divide(matches, recall_total) + precision_score = _safe_divide(matches, precision_total) + denom = (1.0 - alpha) * precision_score + alpha * recall_score + if denom > 0.0: + return (precision_score * recall_score) / denom + else: + return 0.0 + +def rouge_n(peer, model, n): + """ + Compute the ROUGE-N score of a peer with respect to one or more models, for + a given value of `n`. + """ + #matches = 0 + #recall_total = 0 + peer_counter = _ngram_counts(peer, n) + ''' + for model in models: + model_counter = _ngram_counts(model, n) + matches += _counter_overlap(peer_counter, model_counter) + recall_total += _ngram_count(model, n) + ''' + model_counter = _ngram_counts(model, n) + matches = _counter_overlap(peer_counter, model_counter) + recall_total = _ngram_count(model, n) + precision_total = _ngram_count(peer, n) + #precision_total = len(models) * _ngram_count(peer, n) + #print matches, recall_total, precision_total + return _safe_f1(matches, recall_total, precision_total) + +def rouge_1(peer, model): + """ + Compute the ROUGE-1 (unigram) score of a peer with respect to one or more + models. 
+ """ + return rouge_n(peer, model, 1) + +def rouge_2(peer, model): + """ + Compute the ROUGE-2 (bigram) score of a peer with respect to one or more + models. + """ + return rouge_n(peer, model, 2) + +def rouge_3(peer, model): + """ + Compute the ROUGE-3 (trigram) score of a peer with respect to one or more + models. + """ + return rouge_n(peer, model, 3) + +def lcs(a, b): + """ + Compute the length of the longest common subsequence between two sequences. + Time complexity: O(len(a) * len(b)) + Space complexity: O(min(len(a), len(b))) + """ + # This is an adaptation of the standard LCS dynamic programming algorithm + # tweaked for lower memory consumption. + # Sequence a is laid out along the rows, b along the columns. + # Minimize number of columns to minimize required memory + if len(a) < len(b): + a, b = b, a + # Sequence b now has the minimum length + # Quit early if one sequence is empty + if len(b) == 0: + return 0 + # Use a single buffer to store the counts for the current row, and + # overwrite it on each pass + row = [0] * len(b) + for ai in a: + left = 0 + diag = 0 + for j, bj in enumerate(b): + up = row[j] + if ai == bj: + value = diag + 1 + else: + value = max(left, up) + row[j] = value + left = value + diag = up + # Return the last cell of the last row + return left + +def rouge_l(peer, models): + """ + Compute the ROUGE-L score of a peer with respect to one or more models. + """ + matches = 0 + recall_total = 0 + for model in models: + matches += lcs(model, peer) + recall_total += len(model) + precision_total = len(models) * len(peer) + return _safe_f1(matches, recall_total, precision_total, alpha) + +def rouge_1_corpus(peers, models): + curpus_size = len(peers) + rouge_score = 0 + for (peer, model) in zip(peers, models): + rouge_score += rouge_1(peer, model) + #print rouge_1(peer, model) + #print "========" + return rouge_score / curpus_size + +def rouge_2_corpus(peers, models): + curpus_size = len(peers) + rouge_score = 0 + for (peer, model) in zip(peers, models): + #print rouge_2(peer, model) + rouge_score += rouge_2(peer, model) + #print "=======" + return rouge_score / curpus_size + +if __name__ == '__main__': + pass diff --git a/test.sh b/test.sh new file mode 100644 index 0000000..f7037fc --- /dev/null +++ b/test.sh @@ -0,0 +1,7 @@ +CUDA_VISIBLE_DEVICES=0 python nmt.py \ + --cuda \ + --mode test \ + --load_model /path/to/trained/model \ + --save_to_file output \ + --test_src /path/to/valid/article \ + --test_tgt /path/to/valid/title diff --git a/train.sh b/train.sh new file mode 100644 index 0000000..bc4f3b0 --- /dev/null +++ b/train.sh @@ -0,0 +1,17 @@ +#!/bin/sh + +CUDA_VISIBLE_DEVICES=0 python nmt.py \ + --cuda \ + --mode train \ + --vocab "path/to/train/vocab" \ + --save_to "path/to/saved/models" \ + --train_src "path/to/train/article" \ + --train_tgt "path/to/train/title" \ + --dev_src "path/to/dev/article" \ + --dev_tgt "path/to/dev/title" \ + --test_src "path/to/test/article" \ + --test_tgt "path/to/test/title" \ + --train_ent_x "path/to/train/entailment/premise" \ + --train_ent_y "path/to/train/entailment/hypothesis" \ + --train_ent_label "path/to/train/entailment/label" \ + 2>&1 | tee -a train.mtl.log diff --git a/util.py b/util.py new file mode 100644 index 0000000..bb82ead --- /dev/null +++ b/util.py @@ -0,0 +1,65 @@ +from collections import defaultdict +import numpy as np + +def map_ent_label(labels): + ent_label_dict = {'entailment': 0, 'neutral': 1, 'contradiction': 2} + new_labels = [[ent_label_dict[label[0]]] for label in labels] + return new_labels + 
+def read_corpus(file_path, source):
+    data = []
+    for line in open(file_path):
+        sent = line.strip().split(' ')
+        # only append <s> and </s> to the target sentence
+        if source == 'tgt':
+            sent = ['<s>'] + sent + ['</s>']
+        data.append(sent)
+
+    return data
+
+
+def batch_slice(data, batch_size, sort=True, ent=False):
+    if len(data[0]) == 3:
+        ent = True
+    batch_num = int(np.ceil(len(data) / float(batch_size)))
+    for i in xrange(batch_num):
+        cur_batch_size = batch_size if i < batch_num - 1 else len(data) - batch_size * i
+        src_sents = [data[i * batch_size + b][0] for b in range(cur_batch_size)]
+        tgt_sents = [data[i * batch_size + b][1] for b in range(cur_batch_size)]
+        if ent:
+            labels = [data[i * batch_size + b][2] for b in range(cur_batch_size)]
+
+        if sort:
+            src_ids = sorted(range(cur_batch_size), key=lambda src_id: len(src_sents[src_id]), reverse=True)
+            src_sents = [src_sents[src_id] for src_id in src_ids]
+            tgt_sents = [tgt_sents[src_id] for src_id in src_ids]
+            if ent:
+                labels = [labels[src_id] for src_id in src_ids]
+
+        if ent:
+            yield src_sents, tgt_sents, labels
+        else:
+            yield src_sents, tgt_sents
+
+
+def data_iter(data, batch_size, shuffle=True):
+    """
+    randomly permute data, then sort by source length, and partition into batches
+    ensure that the length of source sentences in each batch is decreasing
+    """
+
+    buckets = defaultdict(list)
+    for pair in data:
+        src_sent = pair[0]
+        buckets[len(src_sent)].append(pair)
+
+    batched_data = []
+    for src_len in buckets:
+        tuples = buckets[src_len]
+        if shuffle: np.random.shuffle(tuples)
+        batched_data.extend(list(batch_slice(tuples, batch_size)))
+
+    if shuffle:
+        np.random.shuffle(batched_data)
+    for batch in batched_data:
+        yield batch
diff --git a/vocab.py b/vocab.py
new file mode 100644
index 0000000..55eb059
--- /dev/null
+++ b/vocab.py
@@ -0,0 +1,105 @@
+from __future__ import print_function
+import argparse
+from collections import Counter
+from itertools import chain
+
+import torch
+
+from util import read_corpus
+
+
+class VocabEntry(object):
+    def __init__(self):
+        self.word2id = dict()
+        self.unk_id = 3
+        self.word2id['<pad>'] = 0
+        self.word2id['<s>'] = 1
+        self.word2id['</s>'] = 2
+        self.word2id['<unk>'] = 3
+
+        self.id2word = {v: k for k, v in self.word2id.iteritems()}
+
+    def __getitem__(self, word):
+        return self.word2id.get(word, self.unk_id)
+
+    def __contains__(self, word):
+        return word in self.word2id
+
+    def __setitem__(self, key, value):
+        raise ValueError('vocabulary is readonly')
+
+    def __len__(self):
+        return len(self.word2id)
+
+    def __repr__(self):
+        return 'Vocabulary[size=%d]' % len(self)
+
+    def id2word(self, wid):
+        return self.id2word[wid]
+
+    def add(self, word):
+        if word not in self:
+            wid = self.word2id[word] = len(self)
+            self.id2word[wid] = word
+            return wid
+        else:
+            return self[word]
+
+    @staticmethod
+    def from_corpus(corpus, size, remove_singleton=True):
+        vocab_entry = VocabEntry()
+
+        word_freq = Counter(chain(*corpus))
+        non_singletons = [w for w in word_freq if word_freq[w] > 1]
+        print('number of word types: %d, number of word types w/ frequency > 1: %d' % (len(word_freq),
+                                                                                       len(non_singletons)))
+
+        top_k_words = sorted(word_freq.keys(), reverse=True, key=word_freq.get)[:size]
+
+        for word in top_k_words:
+            if len(vocab_entry) < size:
+                if not (word_freq[word] == 1 and remove_singleton):
+                    vocab_entry.add(word)
+
+        return vocab_entry
+
+
+class Vocab(object):
+    def __init__(self, src_sents, tgt_sents, src_vocab_size, tgt_vocab_size, remove_singleton=True):
+        assert len(src_sents) == len(tgt_sents)
+
+        print('initialize source vocabulary ..')
+        self.src = VocabEntry.from_corpus(src_sents, src_vocab_size, remove_singleton=remove_singleton)
+
+        print('initialize target vocabulary ..')
+        self.tgt = VocabEntry.from_corpus(tgt_sents, tgt_vocab_size, remove_singleton=remove_singleton)
+
+    def __repr__(self):
+        return 'Vocab(source %d words, target %d words)' % (len(self.src), len(self.tgt))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--src_vocab_size', default=500000, type=int, help='source vocabulary size')
+    parser.add_argument('--tgt_vocab_size', default=500000, type=int, help='target vocabulary size')
+    parser.add_argument('--include_singleton', action='store_true', default=False, help='whether to include singletons '
+                                                                                        'in the vocabulary (default=False)')
+
+    parser.add_argument('--train_src', type=str, required=True, help='file of source sentences')
+    parser.add_argument('--train_tgt', type=str, required=True, help='file of target sentences')
+
+    parser.add_argument('--output', default='vocab.bin', type=str, help='output vocabulary file')
+
+    args = parser.parse_args()
+
+    print('read in source sentences: %s' % args.train_src)
+    print('read in target sentences: %s' % args.train_tgt)
+
+    src_sents = read_corpus(args.train_src, source='src')
+    tgt_sents = read_corpus(args.train_tgt, source='tgt')
+
+    vocab = Vocab(src_sents, tgt_sents, args.src_vocab_size, args.tgt_vocab_size, remove_singleton=not args.include_singleton)
+    print('generated vocabulary, source %d words, target %d words' % (len(vocab.src), len(vocab.tgt)))
+
+    torch.save(vocab, args.output)
+    print('vocabulary saved to %s' % args.output)
diff --git a/vocab.sh b/vocab.sh
new file mode 100644
index 0000000..e31b2de
--- /dev/null
+++ b/vocab.sh
@@ -0,0 +1,4 @@
+python vocab.py \
+    --include_singleton \
+    --train_src /path/to/train/article \
+    --train_tgt /path/to/train/title
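+# Illustrative note (workflow inferred from the argument defaults, not stated in the scripts):
+# without --output, vocab.py writes the vocabulary to vocab.bin; point --vocab in train.sh
+# at that file before running training.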