Skip to content

Commit

Permalink
delete pre word2vec
Browse files Browse the repository at this point in the history
  • Loading branch information
zn-nlp committed Nov 6, 2019
1 parent 29590f8 commit 2d520c1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
15 changes: 7 additions & 8 deletions seq2seq_tf2/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@


class Encoder(tf.keras.layers.Layer):
def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz, embedding_matrix):
def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
super(Encoder, self).__init__()
self.batch_sz = batch_sz
self.enc_units = enc_units
self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,
weights=[embedding_matrix],
trainable=False)
self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
self.gru = tf.keras.layers.CuDNNGRU(self.enc_units,
Expand Down Expand Up @@ -65,13 +63,14 @@ def call(self, query, values):


class Decoder(tf.keras.layers.Layer):
def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz, embedding_matrix):
def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
super(Decoder, self).__init__()
self.batch_sz = batch_sz
self.dec_units = dec_units
self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,
weights=[embedding_matrix],
trainable=False)
self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
# self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,
# weights=[embedding_matrix],
# trainable=False)
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
self.gru = tf.keras.layers.CuDNNGRU(self.enc_units,
Expand Down
6 changes: 3 additions & 3 deletions seq2seq_tf2/seq2seq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
class PGN(tf.keras.Model):
def __init__(self, params):
super(PGN, self).__init__()
self.embedding_matrix = load_word2vec(params["vocab_size"])
# self.embedding_matrix = load_word2vec(params["vocab_size"])
self.params = params
self.encoder = Encoder(params["vocab_size"], params["embed_size"], params["enc_units"], params["batch_size"], self.embedding_matrix)
self.encoder = Encoder(params["vocab_size"], params["embed_size"], params["enc_units"], params["batch_size"])
self.attention = BahdanauAttention(params["attn_units"])
self.decoder = Decoder(params["vocab_size"], params["embed_size"], params["dec_units"], params["batch_size"], self.embedding_matrix)
self.decoder = Decoder(params["vocab_size"], params["embed_size"], params["dec_units"], params["batch_size"])
self.pointer = Pointer()

def call_encoder(self, enc_inp):
Expand Down

0 comments on commit 2d520c1

Please sign in to comment.