From ea925188177379e9e57f267224ce402ae503da01 Mon Sep 17 00:00:00 2001 From: Aaditya GPGPU Date: Wed, 16 Nov 2016 00:28:36 -0500 Subject: [PATCH] fix the model restore path error --- params.py | 1 + train.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/params.py b/params.py index d66e1b9..83a1d6a 100644 --- a/params.py +++ b/params.py @@ -84,6 +84,7 @@ def __init__(self, verbose): self.data_train_path = './data/train.pickle' self.data_test_path = './data/test.pickle' self.resume_training = True + self.on_resume_fix_lr = True if verbose: pprint(self.__dict__) diff --git a/train.py b/train.py index 1506366..a7bee2a 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,5 @@ from __future__ import division +from __future__ import print_function import math import os import tensorflow as tf @@ -55,7 +56,10 @@ def sparse_labels_or_not(batch): sess.run(tf.initialize_all_variables()) if tparam.resume_training: - saver.restore(sess, tparam.model_path + '/model') + saver.restore(sess, tparam.model_path + 'model') + if tparam.on_resume_fix_lr: + optimizer = tf.train.AdamOptimizer(tparam.learning_rate) + print("model restored...") # for the pretty pretty tensorboard summary_writer = tf.train.SummaryWriter('tensorboards', sess.graph)