diff --git a/seq2seq_tf2/batcher.py b/seq2seq_tf2/batcher.py
index dbeec36..db42a3f 100644
--- a/seq2seq_tf2/batcher.py
+++ b/seq2seq_tf2/batcher.py
@@ -116,16 +116,10 @@ def abstract_to_sents(abstract):
     sents = []
     while True:
         try:
-            print('SENTENCE_START is ', SENTENCE_START)
-            print('in abstract is ', abstract)
             start_p = abstract.index(SENTENCE_START, cur)
-            print('start_p is ', start_p)
             end_p = abstract.index(SENTENCE_END, start_p + 1)
-            print('end_p is ', end_p)
             cur = end_p + len(SENTENCE_END)
-            print('cur is ', cur)
             sents.append(abstract[start_p + len(SENTENCE_START): end_p])
-            print('sents is ', sents)
         except ValueError as e:  # no more sentences
             return sents
 
@@ -162,7 +156,7 @@ def get_dec_inp_targ_seqs(sequence, max_len, start_id, stop_id):
 # return parsed_example
 
 
-def example_generator(filenames_1, filenames_2, vocab_path, vocab_size, max_enc_len, max_dec_len, mode):
+def example_generator(filenames_1, filenames_2, vocab, max_enc_len, max_dec_len, mode, batch_size):
     dataset_1 = tf.data.TextLineDataset(filenames_1)
     dataset_2 = tf.data.TextLineDataset(filenames_2)
 
@@ -170,9 +164,6 @@ def example_generator(filenames_1, filenames_2, vocab_path, vocab_size, max_enc_
     if mode == "train":
         train_dataset = train_dataset.shuffle(10, reshuffle_each_iteration=True).repeat()
 
-    vocab = Vocab(vocab_path, vocab_size)
-    # print('vocab is {}'.format(vocab.word2id))
-
     for raw_record in train_dataset:
         article = raw_record[0].numpy().decode("utf-8")
         # print('article is ', article)
@@ -224,13 +215,15 @@ def example_generator(filenames_1, filenames_2, vocab_path, vocab_size, max_enc_
             "abstract": abstract,
             "abstract_sents": abstract
         }
-        # print('output is ', output)
-        yield output
+        if mode == "test":
+            for _ in range(batch_size):
+                yield output
+        else:
+            yield output
 
 
-def batch_generator(generator, filenames_1, filenames_2, vocab_path, vocab_size, max_enc_len, max_dec_len, batch_size, mode):
-    dataset = tf.data.Dataset.from_generator(generator,
-                                             args=[filenames_1, filenames_2, vocab_path, vocab_size, max_enc_len, max_dec_len, mode],
+def batch_generator(generator, filenames_1, filenames_2, vocab, max_enc_len, max_dec_len, batch_size, mode):
+    dataset = tf.data.Dataset.from_generator(lambda: generator(filenames_1, filenames_2, vocab, max_enc_len, max_dec_len, mode, batch_size),
                                              output_types={
                                                  "enc_len": tf.int32,
                                                  "enc_input": tf.int32,
@@ -296,9 +289,9 @@ def update(entry):
     return dataset
 
 
-def batcher(filenames_1, filenames_2, vocab_path, hpm):
+def batcher(filenames_1, filenames_2, vocab, hpm):
     # filenames = glob.glob("{}/*.tfrecords".format(data_path))
-    dataset = batch_generator(example_generator, filenames_1, filenames_2, vocab_path, hpm["vocab_size"], hpm["max_enc_len"],
+    dataset = batch_generator(example_generator, filenames_1, filenames_2, vocab, hpm["max_enc_len"],
                               hpm["max_dec_len"], hpm["batch_size"], hpm["mode"])
     return dataset
 
diff --git a/seq2seq_tf2/layers.py b/seq2seq_tf2/layers.py
index 8b62df7..c19bcf5 100644
--- a/seq2seq_tf2/layers.py
+++ b/seq2seq_tf2/layers.py
@@ -4,31 +4,29 @@ class Encoder(tf.keras.layers.Layer):
-    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
+    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz, embedding_matrix):
         super(Encoder, self).__init__()
         self.batch_sz = batch_sz
         self.enc_units = enc_units
-        # embedding_matrix = load_word2vec(vocab_size)
-        # self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,
-        #                                            weights=[embedding_matrix],
-        #                                            trainable=False)
-        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
+        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,
+                                                   weights=[embedding_matrix],
+                                                   trainable=False)
         self.gru = tf.keras.layers.GRU(self.enc_units,
                                        return_sequences=True,
                                        return_state=True,
                                        recurrent_initializer='glorot_uniform')
-        # self.bigru = tf.keras.layers.Bidirectional(self.gru, merge_mode='concat')
+        self.bigru = tf.keras.layers.Bidirectional(self.gru, merge_mode='concat')
 
     def call(self, x, hidden):
         x = self.embedding(x)
-        # hidden = tf.split(hidden, num_or_size_splits=2, axis=1)
-        # output, forward_state, backward_state = self.bigru(x, initial_state=hidden)
-        # state = tf.concat([forward_state, backward_state], axis=1)
-        output, state = self.gru(x, initial_state=hidden)
+        hidden = tf.split(hidden, num_or_size_splits=2, axis=1)
+        output, forward_state, backward_state = self.bigru(x, initial_state=hidden)
+        state = tf.concat([forward_state, backward_state], axis=1)
+        # output, state = self.gru(x, initial_state=hidden)
         return output, state
 
     def initialize_hidden_state(self):
-        return tf.zeros((self.batch_sz, self.enc_units))
+        return tf.zeros((self.batch_sz, 2*self.enc_units))
 
 
 class BahdanauAttention(tf.keras.layers.Layer):
@@ -60,21 +58,19 @@ def call(self, query, values):
 
 
 class Decoder(tf.keras.layers.Layer):
-    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
+    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz, embedding_matrix):
         super(Decoder, self).__init__()
         self.batch_sz = batch_sz
         self.dec_units = dec_units
-        # embedding_matrix = load_word2vec(vocab_size)
-        # self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,
-        #                                            weights=[embedding_matrix],
-        #                                            trainable=False)
-        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
+        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,
+                                                   weights=[embedding_matrix],
+                                                   trainable=False)
         self.gru = tf.keras.layers.GRU(self.dec_units,
                                        return_sequences=True,
                                        return_state=True,
                                        recurrent_initializer='glorot_uniform')
         self.fc = tf.keras.layers.Dense(vocab_size, activation=tf.keras.activations.softmax)
-        # self.fc = tf.nn.dropout(0.5)
+        self.dropout = tf.keras.layers.Dropout(0.5)  # keep dropout on its own attribute so it does not overwrite the self.fc vocab projection
 
     def call(self, x, hidden, enc_output, context_vector):
         # enc_output shape == (batch_size, max_length, hidden_size)
diff --git a/seq2seq_tf2/main.py b/seq2seq_tf2/run_summarization.py
similarity index 81%
rename from seq2seq_tf2/main.py
rename to seq2seq_tf2/run_summarization.py
index b88dd24..8d4fb0b 100644
--- a/seq2seq_tf2/main.py
+++ b/seq2seq_tf2/run_summarization.py
@@ -12,16 +12,23 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--max_enc_len", default=400, help="Encoder input max sequence length", type=int)
     parser.add_argument("--max_dec_len", default=100, help="Decoder input max sequence length", type=int)
+    parser.add_argument("--max_dec_steps", default=120, help="Maximum number of words of the predicted abstract", type=int)
+    parser.add_argument("--min_dec_steps", default=30, help="Minimum number of words of the predicted abstract", type=int)
     parser.add_argument("--batch_size", default=16, help="batch size", type=int)
+    parser.add_argument("--beam_size", default=3,
+                        help="beam size for beam search decoding (must be equal to batch size in decode mode)",
+                        type=int)
     parser.add_argument("--vocab_size", default=50000, help="Vocabulary size", type=int)
     parser.add_argument("--embed_size", default=256, help="Words embeddings dimension", type=int)
     parser.add_argument("--enc_units", default=256, help="Encoder GRU cell units number", type=int)
     parser.add_argument("--dec_units", default=256, help="Decoder GRU cell units number", type=int)
     parser.add_argument("--attn_units", default=512, help="[context vector, decoder state, decoder input] feedforward result dimension - this result is used to compute the attention weights", type=int)
-    parser.add_argument("--learning_rate", default=0.015, help="Learning rate", type=float)
+    parser.add_argument("--learning_rate", default=0.15, help="Learning rate", type=float)
     parser.add_argument("--adagrad_init_acc", default=0.1, help="Adagrad optimizer initial accumulator value. Please refer to the Adagrad optimizer API documentation on tensorflow site for more details.", type=float)
     parser.add_argument("--max_grad_norm", default=0.8, help="Gradient norm above which gradients must be clipped", type=float)
     parser.add_argument("--checkpoints_save_steps", default=10, help="Save checkpoints every N steps", type=int)
+    parser.add_argument("--max_steps", default=10000, help="Max number of iterations", type=int)
+    parser.add_argument("--num_to_test", default=5, help="Number of examples to test", type=int)
     parser.add_argument("--mode", default='train', help="training, eval or test options")
     parser.add_argument("--pointer_gen", default=False, help="training, eval or test options")
 
@@ -48,11 +55,6 @@ def main():
     # assert os.path.exists(params["data_dir"]), "data_dir doesn't exist"
     # assert os.path.isfile(params["vocab_path"]), "vocab_path doesn't exist"
 
-    if not os.path.exists("{}".format(params["model_dir"])):
-        os.makedirs("{}".format(params["model_dir"]))
-    """i = len([name for name in os.listdir("{}/{}".format(params["model_dir"], "logdir")) if os.path.isfile(name)])
-    params["log_file"] = "{}/logdir/tensorflow_{}.log".format(params["model_dir"],i)"""
-
     if params["mode"] == "train":
         train(params)
 
diff --git a/seq2seq_tf2/seq2seq_model.py b/seq2seq_tf2/seq2seq_model.py
index 65100b9..f8c1a62 100644
--- a/seq2seq_tf2/seq2seq_model.py
+++ b/seq2seq_tf2/seq2seq_model.py
@@ -6,12 +6,11 @@ class PGN(tf.keras.Model):
     def __init__(self, params):
         super(PGN, self).__init__()
-        # self.embedding_matrix = load_word2vec(params["vocab_size"])
-        # print()
+        self.embedding_matrix = load_word2vec(params["vocab_size"])
         self.params = params
-        self.encoder = Encoder(params["vocab_size"], params["embed_size"], params["enc_units"], params["batch_size"])
+        self.encoder = Encoder(params["vocab_size"], params["embed_size"], params["enc_units"], params["batch_size"], self.embedding_matrix)
         self.attention = BahdanauAttention(params["attn_units"])
-        self.decoder = Decoder(params["vocab_size"], params["embed_size"], params["dec_units"], params["batch_size"])
+        self.decoder = Decoder(params["vocab_size"], params["embed_size"], params["dec_units"], params["batch_size"], self.embedding_matrix)
         self.pointer = Pointer()
 
     def call_encoder(self, enc_inp):
@@ -31,12 +30,10 @@ def call_decoder_onestep(self, latest_tokens, enc_hidden, dec_hidden):
                                                  context_vector)
         return dec_x, pred, dec_hidden
 
-
     def call(self, enc_output, dec_hidden, enc_inp, enc_extended_inp, dec_inp, batch_oov_len):
         predictions = []
         attentions = []
         p_gens = []
-
         context_vector, _ = self.attention(dec_hidden, enc_output)
         if self.params["pointer_gen"]:
             for t in range(dec_inp.shape[1]):
@@ -52,12 +49,17 @@ def call(self, enc_output, dec_hidden, enc_inp, enc_extended_inp, dec_inp, batch
 
             final_dists = _calc_final_dist(enc_extended_inp, predictions, attentions, p_gens, batch_oov_len,
                                            self.params["vocab_size"], self.params["batch_size"])
-            return tf.stack(final_dists, 1), dec_hidden
+
+            if self.params["mode"] == "train":
+                return tf.stack(final_dists, 1), dec_hidden  # predictions_shape = (batch_size, dec_len, vocab_size) with dec_len = 1 in pred mode
+            else:
+                return tf.stack(final_dists, 1), dec_hidden, context_vector, tf.stack(attentions, 1), tf.stack(p_gens, 1)
         else:
             print('dec_inp is ', dec_inp)
             print('dec_inp.shape[1] is ', dec_inp.shape[1])
             for t in range(dec_inp.shape[1]):
+                context_vector, _ = self.attention(dec_hidden, enc_output)
                 dec_x, pred, dec_hidden = self.decoder(tf.expand_dims(dec_inp[:, t], 1),
                                                        dec_hidden,
                                                        enc_output,
diff --git a/seq2seq_tf2/test.py b/seq2seq_tf2/test.py
index f14037f..e0e692a 100644
--- a/seq2seq_tf2/test.py
+++ b/seq2seq_tf2/test.py
@@ -1,11 +1,13 @@
 import tensorflow as tf
 from seq2seq_tf2.seq2seq_model import PGN
-from seq2seq_tf2.batcher import Vocab, START_DECODING, STOP_DECODING, article_to_ids, output_to_words, SENTENCE_END
+from seq2seq_tf2.batcher import Vocab, START_DECODING, STOP_DECODING, article_to_ids, output_to_words, SENTENCE_END, batcher
 from seq2seq_tf2.preprocess import preprocess_sentence
+from seq2seq_tf2.test_helper import beam_decode
+from tqdm import tqdm
 from seq2seq_tf2 import config
 import json
 
-params = {'max_enc_len': 400,
+_params = {'max_enc_len': 400,
           'max_dec_len': 100,
           'batch_size': 1,
           'vocab_size': 50000,
@@ -31,60 +33,99 @@
           'log_file': ''}
 
 
-def test(sentence):
-    vocab = Vocab(params["vocab_path"], params["vocab_size"])
-    model = PGN(params)
-
-    ckpt = tf.train.Checkpoint(model=model)
-    checkpoint_dir = "{}/checkpoint".format(params["model_dir"])
-    latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
-    ckpt.restore(latest_ckpt).expect_partial()
-
-    sentence = preprocess_sentence(sentence)
-    print('sentence is ', sentence)
-    sentence_words = sentence.split()[:params["max_enc_len"]]
-    print('sentence_words is ', sentence_words)
-    enc_input = [vocab.word_to_id(w) for w in sentence_words]
-    print('enc_input is ', enc_input)
-    enc_input_extend_vocab, article_oovs = article_to_ids(sentence_words, vocab)
-    print('enc_input_extend_vocab is ', enc_input_extend_vocab)
-    print('article_oovs', article_oovs)
-
-    start_decoding = vocab.word_to_id(START_DECODING)
-    stop_decoding = vocab.word_to_id(STOP_DECODING)
-
-    enc_input = tf.keras.preprocessing.sequence.pad_sequences([enc_input],
-                                                              maxlen=params["max_enc_len"],
-                                                              padding='post')
-    print('enc_input is ', enc_input)
-    enc_input = tf.convert_to_tensor(enc_input)
-    print('enc_input is ', enc_input)
+def test(params):
+    assert params["mode"].lower() == "test", "change training mode to 'test' or 'eval'"
+    assert params["beam_size"] == params["batch_size"], "Beam size must be equal to batch_size, change the params"
 
-    enc_hidden, enc_output = model.call_encoder(enc_input)
-    print('enc_hidden is ', enc_hidden)
-    print('enc_output is ', enc_output)
-    dec_hidden = enc_hidden
-    dec_input = tf.expand_dims([start_decoding], 0)
-    print('dec_input is ', dec_input)
-
-    result = ''
-    while dec_input != vocab.word_to_id(STOP_DECODING):
-        _, predictions, dec_hidden = model.call_decoder_onestep(dec_input, enc_output, dec_hidden)
-        print('predictions is ', predictions)
-
-        predicted_id = tf.argmax(predictions[0]).numpy()
-        print('predicted_id', predicted_id)
-        result += vocab.id_to_word(predicted_id) + ' '
-
-        if vocab.id_to_word(predicted_id) == SENTENCE_END \
-                or len(result.split()) >= params['max_dec_len']:
-            print('Early stopping')
-            break
+    tf.compat.v1.logging.info("Building the model ...")
+    model = PGN(params)
 
-        dec_input = tf.expand_dims([predicted_id], 1)
-        print('dec_input:', dec_input)
+    print("Creating the vocab ...")
+    vocab = Vocab(params["vocab_path"], params["vocab_size"])
 
-    print('result: ', result)
+    print("Creating the batcher ...")
+    b = batcher(params["data_dir"], vocab, params)
+
+    print("Creating the checkpoint manager")
+    checkpoint_dir = "{}".format(params["checkpoint_dir"])
+    ckpt = tf.train.Checkpoint(step=tf.Variable(0), PGN=model)
+    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=11)
+
+    path = params["model_path"] if params["model_path"] else ckpt_manager.latest_checkpoint
+    ckpt.restore(path)
+    print("Model restored")
+
+    for batch in b:
+        yield beam_decode(model, batch, vocab, params)
+
+
+def test_and_save(params):
+    assert params["test_save_dir"], "provide a dir where to save the results"
+    gen = test(params)
+    with tqdm(total=params["num_to_test"], position=0, leave=True) as pbar:
+        for i in range(params["num_to_test"]):
+            trial = next(gen)
+            with open(params["test_save_dir"] + "/article_" + str(i) + ".txt", "w") as f:
+                f.write("article:\n")
+                f.write(trial.text)
+                f.write("\n\nabstract:\n")
+                f.write(trial.abstract)
+            pbar.update(1)
+
+# def _test(sentence):
+#     vocab = Vocab(params["vocab_path"], params["vocab_size"])
+#     model = PGN(params)
+#
+#     ckpt = tf.train.Checkpoint(model=model)
+#     checkpoint_dir = "{}/checkpoint".format(params["model_dir"])
+#     latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
+#     ckpt.restore(latest_ckpt).expect_partial()
+#
+#     sentence = preprocess_sentence(sentence)
+#     print('sentence is ', sentence)
+#     sentence_words = sentence.split()[:params["max_enc_len"]]
+#     print('sentence_words is ', sentence_words)
+#     enc_input = [vocab.word_to_id(w) for w in sentence_words]
+#     print('enc_input is ', enc_input)
+#     enc_input_extend_vocab, article_oovs = article_to_ids(sentence_words, vocab)
+#     print('enc_input_extend_vocab is ', enc_input_extend_vocab)
+#     print('article_oovs', article_oovs)
+#
+#     start_decoding = vocab.word_to_id(START_DECODING)
+#     stop_decoding = vocab.word_to_id(STOP_DECODING)
+#
+#     enc_input = tf.keras.preprocessing.sequence.pad_sequences([enc_input],
+#                                                               maxlen=params["max_enc_len"],
+#                                                               padding='post')
+#     print('enc_input is ', enc_input)
+#     enc_input = tf.convert_to_tensor(enc_input)
+#     print('enc_input is ', enc_input)
+#
+#     enc_hidden, enc_output = model.call_encoder(enc_input)
+#     print('enc_hidden is ', enc_hidden)
+#     print('enc_output is ', enc_output)
+#     dec_hidden = enc_hidden
+#     dec_input = tf.expand_dims([start_decoding], 0)
+#     print('dec_input is ', dec_input)
+#
+#     result = ''
+#     while dec_input != vocab.word_to_id(STOP_DECODING):
+#         _, predictions, dec_hidden = model.call_decoder_onestep(dec_input, enc_output, dec_hidden)
+#         print('predictions is ', predictions)
+#
+#         predicted_id = tf.argmax(predictions[0]).numpy()
+#         print('predicted_id', predicted_id)
+#         result += vocab.id_to_word(predicted_id) + ' '
+#
+#         if vocab.id_to_word(predicted_id) == SENTENCE_END \
+#                 or len(result.split()) >= params['max_dec_len']:
+#             print('Early stopping')
+#             break
+#
+#         dec_input = tf.expand_dims([predicted_id], 1)
+#         print('dec_input:', dec_input)
+#
+#     print('result: ', result)
 
 
 if __name__ == '__main__':
diff --git a/seq2seq_tf2/test_helper.py b/seq2seq_tf2/test_helper.py
new file mode 100644
index 0000000..9e1124e
--- /dev/null
+++ b/seq2seq_tf2/test_helper.py
@@ -0,0 +1,145 @@
+import tensorflow as tf
+import numpy as np
+from seq2seq_tf2.batcher import output_to_words
+
+
+def beam_decode(model, batch, vocab, params):
+
+    def decode_onestep(batch, enc_outputs, dec_state, dec_input):
+        """
+        Method to decode the output step by step (used for beam search decoding)
+        Args:
+            batch : current batch, shape = [beam_size, 1, vocab_size( + max_oov_len if pointer_gen)] (for the beam search decoding, batch_size = beam_size)
+            enc_outputs : hidden outputs computed by the encoder (a bidirectional GRU here)
+            dec_state : beam_size-many list of the previous decoder states
+            dec_input : decoder input, the previously decoded batch_size-many words, shape = [beam_size, embed_size]
+        Returns: A dictionary of the results of all the ops computations (see below for more details)
+        """
+        # dictionary of all the ops that will be computed
+        final_dists, dec_hidden, context_vector, attentions, p_gens = model(enc_outputs, dec_state,
+                                                                            batch[0]["enc_input"],
+                                                                            batch[0]["extended_enc_input"], dec_input,
+                                                                            batch[0]["max_oov_len"])
+        top_k_probs, top_k_ids = tf.nn.top_k(tf.squeeze(final_dists), k=params["beam_size"] * 2)
+        top_k_log_probs = tf.math.log(top_k_probs)
+        results = {"last_context_vector": context_vector,
+                   "dec_state": dec_hidden,
+                   "attention_vec": attentions,
+                   "top_k_ids": top_k_ids,
+                   "top_k_log_probs": top_k_log_probs,
+                   "p_gen": p_gens}
+
+        return results
+
+    # nested class
+
+    class Hypothesis:
+        """ Class designed to hold hypotheses throughout the beam search decoding """
+
+        def __init__(self, tokens, log_probs, state, attn_dists, p_gens):
+            self.tokens = tokens  # list of all the tokens from time 0 to the current time step t
+            self.log_probs = log_probs  # list of the log probabilities of the tokens
+            self.state = state  # decoder state after the last token decoding
+            self.attn_dists = attn_dists  # attention dists of all the tokens
+            self.p_gens = p_gens  # generation probability of all the tokens
+            self.abstract = ""
+            self.text = ""
+            self.real_abstract = ""
+
+        def extend(self, token, log_prob, state, attn_dist, p_gen):
+            """Method to extend the current hypothesis by adding the next decoded token and all the information associated with it"""
+            return Hypothesis(tokens=self.tokens + [token],  # we add the decoded token
+                              log_probs=self.log_probs + [log_prob],  # we add the log prob of the decoded token
+                              state=state,  # we update the state
+                              attn_dists=self.attn_dists + [attn_dist],
+                              # we add the attention dist of the decoded token
+                              p_gens=self.p_gens + [p_gen]  # we add the p_gen
+                              )
+
+        @property
+        def latest_token(self):
+            return self.tokens[-1]
+
+        @property
+        def tot_log_prob(self):
+            return sum(self.log_probs)
+
+        @property
+        def avg_log_prob(self):
+            return self.tot_log_prob / len(self.tokens)
+
+    # end of the nested class
+
+    # We run the encoder once and then we use the results to decode each time step token
+
+    state, enc_outputs = model.call_encoder(batch[0]["enc_input"])
+
+    # Initial hypotheses (beam_size-many list)
+    hyps = [Hypothesis(tokens=[vocab.word_to_id('[START]')],
+                       # we initialize all the beam_size hypotheses with the start token
+                       log_probs=[0.0],  # initial log prob = 0
+                       state=state[0],
+                       # initial dec_state (we will use only the first dec_state because they're initially the same)
+                       attn_dists=[],
+                       p_gens=[],  # no p_gens collected yet
+                       ) for _ in range(params['batch_size'])]  # batch_size == beam_size
+
+    results = []  # list to hold the top beam_size hypotheses
+    steps = 0  # initial step
+
+    while steps < params['max_dec_steps'] and len(results) < params['beam_size']:
+        latest_tokens = [h.latest_token for h in hyps]  # latest token for each hypothesis, shape: [beam_size]
+        # we replace all the OOV ids by the unknown token id
+        latest_tokens = [t if t in range(params['vocab_size']) else vocab.word_to_id('[UNK]') for t in latest_tokens]
+        # we collect the last states for each hypothesis
+        states = [h.state for h in hyps]
+
+        # we decode the top likely 2 x beam_size tokens at time step t for each hypothesis
+        returns = decode_onestep(batch, enc_outputs, tf.stack(states, axis=0), tf.expand_dims(latest_tokens, axis=1))
+        topk_ids, topk_log_probs, new_states, attn_dists, p_gens = returns['top_k_ids'],\
+                                                                   returns['top_k_log_probs'], \
+                                                                   returns['dec_state'],\
+                                                                   returns['attention_vec'],\
+                                                                   np.squeeze(returns["p_gen"])
+        all_hyps = []
+        num_orig_hyps = 1 if steps == 0 else len(hyps)
+
+        for i in range(num_orig_hyps):
+            h, new_state, attn_dist, p_gen = hyps[i], new_states[i], attn_dists[i], p_gens[i]
+
+            for j in range(params['beam_size'] * 2):
+                # we extend each hypothesis with each of the top k tokens (this gives 2 x beam_size new hypotheses for each of the beam_size old hypotheses)
+                new_hyp = h.extend(token=topk_ids[i, j].numpy(),
+                                   log_prob=topk_log_probs[i, j],
+                                   state=new_state,
+                                   attn_dist=attn_dist,
+                                   p_gen=p_gen)
+                all_hyps.append(new_hyp)
+
+        # in the following lines, we sort all the hypotheses and keep only the beam_size most likely ones
+        hyps = []
+        sorted_hyps = sorted(all_hyps, key=lambda h: h.avg_log_prob, reverse=True)
+        for h in sorted_hyps:
+            if h.latest_token == vocab.word_to_id('[STOP]'):
+                if steps >= params['min_dec_steps']:
+                    results.append(h)
+            else:
+                hyps.append(h)
+            if len(hyps) == params['beam_size'] or len(results) == params['beam_size']:
+                break
+
+        steps += 1
+
+    if len(results) == 0:
+        results = hyps
+
+    # At the end of the loop we return the most likely hypothesis, which holds the most likely output sequence, given the input fed to the model
+    hyps_sorted = sorted(results, key=lambda h: h.avg_log_prob, reverse=True)
+    best_hyp = hyps_sorted[0]
+    best_hyp.abstract = " ".join(output_to_words(best_hyp.tokens, vocab, batch[0]["article_oovs"][0])[1:-1])
+    best_hyp.text = batch[0]["article"].numpy()[0].decode()
+    # if params["mode"] == "eval":
+    #     best_hyp.real_abstract = batch[1]["abstract"].numpy()[0].decode()
+    return best_hyp
diff --git a/seq2seq_tf2/train.py b/seq2seq_tf2/train.py
index 83f7e5a..1d72a42 100644
--- a/seq2seq_tf2/train.py
+++ b/seq2seq_tf2/train.py
@@ -5,21 +5,23 @@ from seq2seq_tf2 import config
 from seq2seq_tf2.seq2seq_model import PGN
-from seq2seq_tf2.batcher import batcher
+from seq2seq_tf2.batcher import batcher, Vocab
 from seq2seq_tf2.train_helper import train_model
 
 
 def train(params):
     assert params["mode"].lower() == "train", "change training mode to 'train'"
 
-    tf.compat.v1.logging.info("Building the model ...")
+    print("Building the model ...")
     model = PGN(params)
 
-    tf.compat.v1.logging.info("Creating the batcher ...")
-    b = batcher(params["train_seg_x_dir"], params["train_seg_y_dir"], params["vocab_path"], params)
+    print("Creating the vocab ...")
+    vocab = Vocab(params["vocab_path"], params["vocab_size"])
 
-    tf.compat.v1.logging.info("Creating the checkpoint manager")
-    logdir = "{}/logdir".format(params["model_dir"])
+    print("Creating the batcher ...")
+    b = batcher(params["train_seg_x_dir"], params["train_seg_y_dir"], vocab, params)
+
+    print("Creating the checkpoint manager")
     checkpoint_dir = "{}/checkpoint".format(params["model_dir"])
     ckpt = tf.train.Checkpoint(step=tf.Variable(0), PGN=model)
     ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=11)
 
@@ -30,7 +32,7 @@ def train(params):
     else:
         print("Initializing from scratch.")
 
-    tf.compat.v1.logging.info("Starting the training ...")
+    print("Starting the training ...")
 
     train_model(model, b, params, ckpt, ckpt_manager)
 
diff --git a/seq2seq_tf2/train_helper.py b/seq2seq_tf2/train_helper.py
index 9145895..563751a 100644
--- a/seq2seq_tf2/train_helper.py
+++ b/seq2seq_tf2/train_helper.py
@@ -26,7 +26,6 @@ def train_step(enc_inp, enc_extended_inp, dec_inp, dec_tar, batch_oov_len):
         loss = 0
         with tf.GradientTape() as tape:
-            print('enc_inp is ', enc_inp)
             enc_hidden, enc_output = model.call_encoder(enc_inp)
             predictions, _ = model(enc_output, enc_hidden, enc_inp, enc_extended_inp, dec_inp, batch_oov_len)
             loss = loss_function(dec_tar, predictions)
@@ -38,7 +37,7 @@ def train_step(enc_inp, enc_extended_inp, dec_inp, dec_tar, batch_oov_len):
     try:
         for batch in dataset:
-            print("batch is {}".format(batch))
+            # print("batch is {}".format(batch))
            t0 = time.time()
             print('batch[0]["enc_input"] is ', batch[0]["enc_input"])
             loss = train_step(batch[0]["enc_input"], batch[0]["extended_enc_input"], batch[1]["dec_input"],
@@ -46,10 +45,15 @@ def train_step(enc_inp, enc_extended_inp, dec_inp, dec_tar, batch_oov_len):
             print('Step {}, time {:.4f}, Loss {:.4f}'.format(int(ckpt.step), time.time() - t0, loss.numpy()))
+            if int(ckpt.step) == params["max_steps"]:
+                ckpt_manager.save(checkpoint_number=int(ckpt.step))
+                print("Saved checkpoint for step {}".format(int(ckpt.step)))
+                break
             if int(ckpt.step) % params["checkpoints_save_steps"] == 0:
                 ckpt_manager.save(checkpoint_number=int(ckpt.step))
                 print("Saved checkpoint for step {}".format(int(ckpt.step)))
             ckpt.step.assign_add(1)
+
     except KeyboardInterrupt:
         ckpt_manager.save(int(ckpt.step))
         print("Saved checkpoint for step {}".format(int(ckpt.step)))
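
For orientation, a minimal wiring sketch of how these pieces fit together, using only the signatures visible in this patch (Vocab, batcher, PGN). It is illustrative, not part of the patch: the file paths are placeholders, and the params keys are assumed to mirror the argparse flags in run_summarization.py.

# Hedged sketch: placeholder paths; params keys follow the argparse defaults shown above.
from seq2seq_tf2.batcher import Vocab, batcher
from seq2seq_tf2.seq2seq_model import PGN

params = {
    "mode": "train", "batch_size": 16, "vocab_size": 50000,
    "max_enc_len": 400, "max_dec_len": 100, "embed_size": 256,
    "enc_units": 256, "dec_units": 256, "attn_units": 512,
    "pointer_gen": False,
    "vocab_path": "data/vocab.txt",          # placeholder path
    "train_seg_x_dir": "data/train_x.txt",   # placeholder path
    "train_seg_y_dir": "data/train_y.txt",   # placeholder path
}

vocab = Vocab(params["vocab_path"], params["vocab_size"])  # the Vocab object is built once and handed to the batcher
dataset = batcher(params["train_seg_x_dir"], params["train_seg_y_dir"], vocab, params)
model = PGN(params)  # loads the frozen word2vec embedding matrix internally
# checkpointing and the training loop are set up in train.py / train_helper.py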