Commit
change the validation loss parameter
iamaaditya committed Nov 16, 2016
1 parent c281eb8 commit 5fbd91b
Showing 1 changed file with 27 additions and 18 deletions.
45 changes: 27 additions & 18 deletions train.py
@@ -19,7 +19,8 @@
data_test = pd.read_pickle(tparam.data_test_path)
len_train = len(data_train)
len_test = len(data_test)
num_batches = int(math.ceil(len_train/tparam.batch_size))
train_b_num = int(math.ceil(len_train/tparam.batch_size))
test_b_num = int(math.ceil(len_test/tparam.batch_size))
images_tf = tf.placeholder(tf.float32, [None, hyper.image_h, hyper.image_w, hyper.image_c], name = "images")
if hyper.sparse:
    labels_tf = tf.placeholder(tf.int64, [None], name = 'labels')
@@ -35,15 +36,27 @@
    loss_tf = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(prob_tf, labels_tf))
else:
    loss_tf = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(prob_tf, labels_tf))
loss_tf_write = tf.scalar_summary("training loss", loss_tf)
train_loss = tf.scalar_summary("training loss", loss_tf)
test_loss = tf.scalar_summary("validation loss", loss_tf)
optimizer = tf.train.AdamOptimizer(tparam.learning_rate)
train_op = optimizer.minimize(loss_tf)

def sparse_labels_or_not(batch):
    if hyper.sparse:
        return batch['label'].values
    else:
        labels = np.zeros((len(batch), hyper.n_labels))
        for i,j in enumerate(batch['label'].values):
            labels[i,j] = 1
        return labels

with tf.Session() as sess:
    saver = tf.train.Saver()
    sess.run(tf.initialize_all_variables())

    if tparam.resume_training:
        saver.restore(sess, tparam.model_path + '/model')

    # for the pretty pretty tensorboard
    summary_writer = tf.train.SummaryWriter('tensorboards', sess.graph)

@@ -53,28 +66,24 @@
        epoch_loss = 0
        for b, train_batch in enumerate(chunker(data_train.sample(frac=1),tparam.batch_size)):
            train_images = np.array(map(lambda i: load_image(i), train_batch['image_path'].values))
            if hyper.sparse:
                train_labels = train_batch['label'].values
            else:
                train_labels = np.zeros((len(train_batch), hyper.n_labels))
                for i,j in enumerate(train_batch['label'].values):
                    train_labels[i,j] = 1
            _, batch_loss, loss_sw = sess.run([train_op, loss_tf, loss_tf_write], feed_dict={images_tf: train_images, labels_tf: train_labels})
            train_labels = sparse_labels_or_not(train_batch)
            _, batch_loss, loss_sw = sess.run([train_op, loss_tf, train_loss], feed_dict={images_tf: train_images, labels_tf: train_labels})

            average_batch_loss = np.average(batch_loss)
            epoch_loss += average_batch_loss
            summary_writer.add_summary(loss_sw, epoch*220+b)
            print("Train: epoch:{}, batch:{}/{}, loss:{}".format(epoch, b, num_batches, average_batch_loss))
        print("Train: epoch:{}, total loss:{}".format(epoch, epoch_loss/num_batches))
            summary_writer.add_summary(loss_sw, epoch*train_b_num+b)
            print("Train: epoch:{}, batch:{}/{}, loss:{}".format(epoch, b, train_b_num, average_batch_loss))
        print("Train: epoch:{}, total loss:{}".format(epoch, epoch_loss/train_b_num))

        # Validation
        correct_count = 0
        for b, test_batch in enumerate(chunker(data_test,tparam.batch_size)):
        validation_loss = 0
        for b, test_batch in enumerate(chunker(data_test,tparam.batch_size)): # no need to randomize test batch
            test_images = np.array(map(lambda i: load_image(i), test_batch['image_path'].values))
            test_labels = test_batch['label'].values
            probs_val = sess.run(prob_tf, feed_dict={images_tf:test_images})
            correct_count += (probs_val.argmax(axis=1) == test_labels).sum()
        print("Test: epoch:{}, accuracy:{}".format(epoch, correct_count/len_test))
            # don't run the train_op by mistake ! ;-)
            test_labels = sparse_labels_or_not(test_batch)
            batch_loss, loss_sw = sess.run([loss_tf, test_loss], feed_dict={images_tf: test_images, labels_tf: test_labels})
            validation_loss += np.average(batch_loss)
            summary_writer.add_summary(loss_sw, epoch*test_b_num+b)
        print("Test: epoch:{}, total loss:{}".format(epoch, validation_loss/test_b_num))
print("Time for one epoch:{}".format(time()-start))
# save the model
saver.save(sess, tparam.model_path + '/model')
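
The gist of the diff above is that the single "training loss" summary op is split into two scalar summaries over the same loss_tf, so train and validation loss appear as separate curves in TensorBoard. A minimal, self-contained sketch of that pattern, using the same TF 0.x API calls as the file (the placeholder loss and the step values here are made up for illustration, not taken from the commit):

import tensorflow as tf

# Stand-in loss: any scalar tensor works; here a fed placeholder.
loss_tf = tf.placeholder(tf.float32, [], name="loss")

# Same tensor, two summary tags -> two separate curves in TensorBoard.
train_loss = tf.scalar_summary("training loss", loss_tf)
test_loss = tf.scalar_summary("validation loss", loss_tf)

with tf.Session() as sess:
    writer = tf.train.SummaryWriter('tensorboards', sess.graph)
    for step, value in enumerate([0.9, 0.7, 0.5]):
        summary_str = sess.run(train_loss, feed_dict={loss_tf: value})
        writer.add_summary(summary_str, step)      # logged under "training loss"
    summary_str = sess.run(test_loss, feed_dict={loss_tf: 0.6})
    writer.add_summary(summary_str, 3)             # logged under "validation loss"
    writer.flush()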

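For the label handling that was factored into sparse_labels_or_not: when hyper.sparse is false, the helper turns a vector of integer class ids into one-hot rows. A tiny numpy-only sketch of that conversion (the example ids and n_labels = 4 are made up, not values from the repository):

import numpy as np

n_labels = 4
label_ids = np.array([0, 2, 3, 1])             # e.g. batch['label'].values
one_hot = np.zeros((len(label_ids), n_labels))
for i, j in enumerate(label_ids):              # same loop as in the helper above
    one_hot[i, j] = 1
print(one_hot)                                  # 4x4 matrix with a single 1 per row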