diff --git a/src/interactive_conditional_samples.py b/src/interactive_conditional_samples.py index 2a1dd300d..b49160219 100755 --- a/src/interactive_conditional_samples.py +++ b/src/interactive_conditional_samples.py @@ -20,8 +20,6 @@ def interact_model( if batch_size is None: batch_size = 1 assert nsamples % batch_size == 0 - np.random.seed(seed) - tf.set_random_seed(seed) enc = encoder.get_encoder(model_name) hparams = model.default_hparams() @@ -35,6 +33,8 @@ def interact_model( with tf.Session(graph=tf.Graph()) as sess: context = tf.placeholder(tf.int32, [batch_size, None]) + np.random.seed(seed) + tf.set_random_seed(seed) output = sample.sample_sequence( hparams=hparams, length=length, context=context,