Skip to content

Commit 7850b91

Browse files
authored
Merge pull request #44 from torokati44/master
Allow using a backend other than TensorFlow
2 parents f751cdb + f150989 commit 7850b91

File tree

3 files changed

+7
-14
lines changed

3 files changed

+7
-14
lines changed

textgenrnn/AttentionWeightedAverage.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from keras.engine import InputSpec, Layer
22
from keras import backend as K
33
from keras import initializers
4-
import tensorflow as tf
54

65

76
class AttentionWeightedAverage(Layer):

textgenrnn/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from keras.optimizers import RMSprop
22
from keras.layers import Input, Embedding, Dense, LSTM, Bidirectional
3-
from keras.layers import CuDNNLSTM, concatenate, Reshape, SpatialDropout1D
3+
from keras.layers import concatenate, Reshape, SpatialDropout1D
44
from keras.models import Model
55
from keras import backend as K
66
from .AttentionWeightedAverage import AttentionWeightedAverage
@@ -68,8 +68,9 @@ def textgenrnn_model(num_classes, cfg, context_size=None,
6868

6969

7070
def new_rnn(cfg, layer_num):
71-
has_gpu = len(K.tensorflow_backend._get_available_gpus()) > 0
72-
if has_gpu:
71+
use_cudnnlstm = K.backend() == 'tensorflow' and len(K.tensorflow_backend._get_available_gpus()) > 0
72+
if use_cudnnlstm:
73+
from keras.layers import CuDNNLSTM
7374
if cfg['rnn_bidirectional']:
7475
return Bidirectional(CuDNNLSTM(cfg['rnn_size'],
7576
return_sequences=True),

textgenrnn/utils.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ def textgenrnn_generate(model, vocab,
6767
if not isinstance(temperature, list):
6868
temperature = [temperature]
6969

70-
if model_input_count(model) > 1:
71-
model = Model(inputs=model.input[0], outputs=model.output[1])
70+
if len(model.inputs) > 1:
71+
model = Model(inputs=model.inputs[0], outputs=model.outputs[1])
7272

7373
while next_char != meta_token and len(text) < max_gen_length:
7474
encoded_text = textgenrnn_encode_sequence(text[-maxlen:],
@@ -166,13 +166,6 @@ def textgenrnn_encode_cat(chars, vocab):
166166
return a
167167

168168

169-
def model_input_count(model):
170-
if isinstance(model.input, list):
171-
return len(model.input)
172-
else: # is a Tensor
173-
return model.input.shape[0]
174-
175-
176169
class generate_after_epoch(Callback):
177170
def __init__(self, textgenrnn, gen_epochs, max_gen_length):
178171
self.textgenrnn = textgenrnn
@@ -192,7 +185,7 @@ def __init__(self, weights_name, num_epochs, save_epochs):
192185
self.save_epochs = save_epochs
193186

194187
def on_epoch_end(self, epoch, logs={}):
195-
if model_input_count(self.model) > 1:
188+
if len(self.model.inputs) > 1:
196189
self.model = Model(inputs=self.model.input[0],
197190
outputs=self.model.output[1])
198191
if self.save_epochs > 0 and (epoch+1) % self.save_epochs == 0 and self.num_epochs != (epoch+1):

0 commit comments

Comments (0)