Skip to content

Commit

Permalink
add cuddgru
Browse files Browse the repository at this point in the history
  • Loading branch information
zn-nlp committed Nov 4, 2019
1 parent 49e5392 commit d188e95
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions seq2seq_tf2/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,17 @@ def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz, embedding_mat
self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,
weights=[embedding_matrix],
trainable=False)
self.gru = tf.keras.layers.GRU(self.enc_units,
return_sequences=True,
return_state=True,
recurrent_initializer='glorot_uniform')
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
self.gru = tf.keras.layers.CuDNNGRU(self.enc_units,
return_sequences=True,
return_state=True,
recurrent_initializer='glorot_uniform')
else:
self.gru = tf.keras.layers.GRU(self.enc_units,
return_sequences=True,
return_state=True,
recurrent_initializer='glorot_uniform')
# self.bigru = tf.keras.layers.Bidirectional(self.gru, merge_mode='concat')

def call(self, x, hidden):
Expand Down Expand Up @@ -65,10 +72,17 @@ def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz, embedding_mat
self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,
weights=[embedding_matrix],
trainable=False)
self.gru = tf.keras.layers.GRU(self.dec_units,
return_sequences=True,
return_state=True,
recurrent_initializer='glorot_uniform')
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
self.gru = tf.keras.layers.CuDNNGRU(self.enc_units,
return_sequences=True,
return_state=True,
recurrent_initializer='glorot_uniform')
else:
self.gru = tf.keras.layers.GRU(self.enc_units,
return_sequences=True,
return_state=True,
recurrent_initializer='glorot_uniform')
self.fc = tf.keras.layers.Dense(vocab_size, activation=tf.keras.activations.softmax)
# self.fc = tf.keras.layers.Dropout(0.5)

Expand Down

0 comments on commit d188e95

Please sign in to comment.