
Commit

edvardHua committed Mar 28, 2019
Merge commit 5a84e5f (2 parents: b1fb250 + e48b2aa)
Showing 1 changed file with 20 additions and 26 deletions.
training/src/train.py (46 changes: 20 additions & 26 deletions)
@@ -28,15 +28,13 @@
 from dataset_prepare import CocoPose
 from dataset_augment import set_network_input_wh, set_network_scale
 
-def get_input(batchsize, epoch, is_train=True):
+def get_input_iter(batchsize, epoch, is_train=True):
     if is_train is True:
         input_pipeline = get_train_dataset_pipeline(batch_size=batchsize, epoch=epoch, buffer_size=100)
     else:
         input_pipeline = get_valid_dataset_pipeline(batch_size=batchsize, epoch=epoch, buffer_size=100)
     iter = input_pipeline.make_one_shot_iterator()
-    _ = iter.get_next()
-    return _[0], _[1]
-
+    return iter
 
 def get_loss_and_output(model, batchsize, input_image, input_heat, reuse_variables=None):
     losses = []
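
Note: the helper now returns the iterator itself instead of calling get_next() internally, which is what later makes it possible to ask each concrete iterator for a string handle. A minimal standalone sketch of the same pattern (TF 1.x tf.data API; the function name, shapes, and toy dataset below are stand-ins, not the project's get_train_dataset_pipeline):

import tensorflow as tf

def make_input_iter(batchsize, epoch):
    # Toy stand-in pipeline: (image, heatmap) pairs with arbitrary shapes.
    images = tf.zeros([32, 64, 64, 3])
    heatmaps = tf.zeros([32, 32, 32, 14])
    dataset = (tf.data.Dataset.from_tensor_slices((images, heatmaps))
               .repeat(epoch)
               .batch(batchsize))
    # Return the iterator, not get_next(): the caller can later evaluate
    # iterator.string_handle() and plug it into a feedable iterator.
    return dataset.make_one_shot_iterator()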
@@ -116,8 +114,14 @@ def main(argv=None):
     )
 
     with tf.Graph().as_default(), tf.device("/cpu:0"):
-        input_image, input_heat = get_input(params['batchsize'], params['max_epoch'], is_train=True)
-        valid_input_image, valid_input_heat = get_input(params['batchsize'], params['max_epoch'], is_train=False)
+        train_dataset = get_train_dataset_pipeline(params['batchsize'], params['max_epoch'], buffer_size=100)
+        valid_dataset = get_valid_dataset_pipeline(params['batchsize'], params['max_epoch'], buffer_size=100)
+
+        train_iterator = train_dataset.make_one_shot_iterator()
+        valid_iterator = valid_dataset.make_one_shot_iterator()
+
+        handle = tf.placeholder(tf.string, shape=[])
+        input_iterator = tf.data.Iterator.from_string_handle(handle, train_dataset.output_types, train_dataset.output_shapes)
 
         global_step = tf.Variable(0, trainable=False)
         learning_rate = tf.train.exponential_decay(float(params['lr']), global_step,
@@ -130,33 +134,23 @@ def main(argv=None):
             # cpu (mac only)
             with tf.device("/cpu:0"):
                 with tf.name_scope("CPU_0"):
+                    input_image, input_heat = input_iterator.get_next()
                     loss, last_heat_loss, pred_heat = get_loss_and_output(params['model'], params['batchsize'],
                                                                           input_image, input_heat, reuse_variable)
                     reuse_variable = True
                     grads = opt.compute_gradients(loss)
                     tower_grads.append(grads)
-
-                    valid_loss, valid_last_heat_loss, valid_pred_heat = get_loss_and_output(params['model'],
-                                                                                            params['batchsize'],
-                                                                                            valid_input_image,
-                                                                                            valid_input_heat,
-                                                                                            reuse_variable)
         else:
             # multiple gpus
             for i in range(params['gpus']):
                 with tf.device("/gpu:%d" % i):
                     with tf.name_scope("GPU_%d" % i):
+                        input_image, input_heat = input_iterator.get_next()
                         loss, last_heat_loss, pred_heat = get_loss_and_output(params['model'], params['batchsize'], input_image, input_heat, reuse_variable)
                         reuse_variable = True
                         grads = opt.compute_gradients(loss)
                         tower_grads.append(grads)
-
-                        valid_loss, valid_last_heat_loss, valid_pred_heat = get_loss_and_output(params['model'],
-                                                                                                params['batchsize'],
-                                                                                                valid_input_image,
-                                                                                                valid_input_heat,
-                                                                                                reuse_variable)
 
         grads = average_gradients(tower_grads)
         for grad, var in grads:
             if grad is not None:
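
The graph-side idea behind this hunk: tf.data.Iterator.from_string_handle builds a single get_next() op whose backing dataset is chosen at sess.run time, so each tower's loss ops are built once and serve both training and validation, replacing the duplicated valid_* subgraphs that were deleted above. A hedged, self-contained sketch of that mechanism (toy datasets, not the project's pipelines):

import tensorflow as tf

train_ds = tf.data.Dataset.from_tensor_slices(tf.range(10)).repeat()
valid_ds = tf.data.Dataset.from_tensor_slices(tf.range(100, 110)).repeat()

train_iterator = train_ds.make_one_shot_iterator()
valid_iterator = valid_ds.make_one_shot_iterator()

# One scalar string placeholder selects which concrete iterator feeds the graph.
handle = tf.placeholder(tf.string, shape=[])
input_iterator = tf.data.Iterator.from_string_handle(
    handle, train_ds.output_types, train_ds.output_shapes)
batch = input_iterator.get_next()  # one op serves both train and valid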
@@ -191,7 +185,8 @@ def main(argv=None):
         config.gpu_options.allow_growth = True
         with tf.Session(config=config) as sess:
             init.run()
-
+            train_handle = sess.run(train_iterator.string_handle())
+            valid_handle = sess.run(valid_iterator.string_handle())
             coord = tf.train.Coordinator()
             threads = tf.train.start_queue_runners(sess=sess, coord=coord)
 
@@ -200,19 +195,18 @@ def main(argv=None):
             print("Start training...")
             for step in range(total_step_num):
                 start_time = time.time()
-                _, loss_value, lh_loss, in_image, in_heat, p_heat = sess.run(
-                    [train_op, loss, last_heat_loss, input_image, input_heat, pred_heat]
+                _, loss_value, lh_loss = sess.run([train_op, loss, last_heat_loss],
+                    feed_dict={handle: train_handle}
                 )
                 duration = time.time() - start_time
 
                 if step != 0 and step % params['per_update_tensorboard_step'] == 0:
                     # False will speed up the training time.
                     if params['pred_image_on_tensorboard'] is True:
-
                         valid_loss_value, valid_lh_loss, valid_in_image, valid_in_heat, valid_p_heat = sess.run(
-                            [valid_loss, valid_last_heat_loss, valid_input_image, valid_input_heat, valid_pred_heat]
+                            [loss, last_heat_loss, input_image, input_heat, pred_heat],
+                            feed_dict={handle: valid_handle}
                         )
-
                         result = []
                         for index in range(params['batchsize']):
                             r = CocoPose.display_image(
@@ -241,11 +235,11 @@ def main(argv=None):
                     print(format_str % (datetime.now(), step, loss_value, lh_loss, examples_per_sec, sec_per_batch))
 
                     # tensorboard visualization
-                    merge_op = sess.run(summary_merge_op)
+                    merge_op = sess.run(summary_merge_op, feed_dict={handle: valid_handle})
                     summary_writer.add_summary(merge_op, step)
 
                 # save model
-                if step % params['per_saved_model_step'] == 0:
+                if step != 0 and step % params['per_saved_model_step'] == 0:
                     checkpoint_path = os.path.join(params['modelpath'], training_name, 'model')
                     saver.save(sess, checkpoint_path, global_step=step)
             coord.request_stop()
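At session time the switch then costs only a feed_dict entry, which is what the sess.run changes above do. Wiring the two pieces sketched earlier together end to end (again under the same toy-data assumptions; each concrete iterator keeps its own position, so the train and valid streams do not disturb each other):

import tensorflow as tf

train_ds = tf.data.Dataset.from_tensor_slices(tf.range(10)).repeat()
valid_ds = tf.data.Dataset.from_tensor_slices(tf.range(100, 110)).repeat()
train_iterator = train_ds.make_one_shot_iterator()
valid_iterator = valid_ds.make_one_shot_iterator()

handle = tf.placeholder(tf.string, shape=[])
batch = tf.data.Iterator.from_string_handle(
    handle, train_ds.output_types, train_ds.output_shapes).get_next()

with tf.Session() as sess:
    # string_handle() is a tensor; evaluate it once per concrete iterator.
    train_handle = sess.run(train_iterator.string_handle())
    valid_handle = sess.run(valid_iterator.string_handle())

    print(sess.run(batch, feed_dict={handle: train_handle}))  # 0   (train)
    print(sess.run(batch, feed_dict={handle: valid_handle}))  # 100 (valid)
    print(sess.run(batch, feed_dict={handle: train_handle}))  # 1   (train resumes)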
