@@ -67,8 +67,8 @@ def textgenrnn_generate(model, vocab,
6767 if not isinstance (temperature , list ):
6868 temperature = [temperature ]
6969
70- if model_input_count (model ) > 1 :
71- model = Model (inputs = model .input [0 ], outputs = model .output [1 ])
70+ if len (model . inputs ) > 1 :
71+ model = Model (inputs = model .inputs [0 ], outputs = model .outputs [1 ])
7272
7373 while next_char != meta_token and len (text ) < max_gen_length :
7474 encoded_text = textgenrnn_encode_sequence (text [- maxlen :],
@@ -166,13 +166,6 @@ def textgenrnn_encode_cat(chars, vocab):
166166 return a
167167
168168
169- def model_input_count (model ):
170- if isinstance (model .input , list ):
171- return len (model .input )
172- else : # is a Tensor
173- return model .input .shape [0 ]
174-
175-
176169class generate_after_epoch (Callback ):
177170 def __init__ (self , textgenrnn , gen_epochs , max_gen_length ):
178171 self .textgenrnn = textgenrnn
@@ -192,7 +185,7 @@ def __init__(self, weights_name, num_epochs, save_epochs):
192185 self .save_epochs = save_epochs
193186
194187 def on_epoch_end (self , epoch , logs = {}):
195- if model_input_count (self .model ) > 1 :
188+ if len (self .model . inputs ) > 1 :
196189 self .model = Model (inputs = self .model .input [0 ],
197190 outputs = self .model .output [1 ])
198191 if self .save_epochs > 0 and (epoch + 1 ) % self .save_epochs == 0 and self .num_epochs != (epoch + 1 ):
0 commit comments