Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions osvos.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from PIL import Image

slim = tf.contrib.slim

mean = np.array((104.00699, 116.66877, 122.67892), dtype=np.float32)

def osvos_arg_scope(weight_decay=0.0002):
"""Defines the OSVOS arg scope.
Expand Down Expand Up @@ -160,6 +160,20 @@ def interp_surgery(variables):
return interp_tensors


def preprocess_img_tf(image):
    """Build the TensorFlow ops that adapt an image to the network input.
    Args:
    image: Image to feed to the network; anything tf.cast accepts.
           NOTE(review): presumably a 3-channel image with the channels on
           the last axis, matching the 3-element module-level `mean` -- confirm.
    Returns:
    A float32 tensor (not a numpy array): the input with its last axis
    reversed, the module-level `mean` subtracted, and a new leading axis of
    size 1 added.
    """
    as_float = tf.cast(image, tf.float32)
    # Flip the order of the last axis before subtracting the mean, so each
    # mean entry lines up with the flipped position.
    flipped = tf.reverse(as_float, axis=[-1])
    centered = tf.subtract(flipped, mean)
    # Prepend a singleton axis at position 0.
    return tf.expand_dims(centered, 0)


# TODO: Move preprocessing into TensorFlow
def preprocess_img(image):
"""Preprocess the image to adapt it to network requirements
Expand Down Expand Up @@ -544,7 +558,7 @@ def _train(dataset, initial_ckpt, supervison, learning_rate, logs_path, max_trai
# Average the gradient
for _ in range(0, iter_mean_grad):
batch_image, batch_label = dataset.next_batch(batch_size, 'train')
image = preprocess_img(batch_image[0])
image = preprocess_img_tf(batch_image[0]).eval()
label = preprocess_labels(batch_label[0])
run_res = sess.run([total_loss, merged_summary_op] + grad_accumulator_ops,
feed_dict={input_image: image, input_label: label})
Expand All @@ -564,7 +578,7 @@ def _train(dataset, initial_ckpt, supervison, learning_rate, logs_path, max_trai
# Save a checkpoint
if step % save_step == 0:
if test_image_path is not None:
curr_output = sess.run(img_summary, feed_dict={input_image: preprocess_img(test_image_path)})
curr_output = sess.run(img_summary, feed_dict={input_image: preprocess_img_tf(test_image_path)})
summary_writer.add_summary(curr_output, step)
save_path = saver.save(sess, model_name, global_step=global_step)
print "Model saved in file: %s" % save_path
Expand Down Expand Up @@ -644,7 +658,7 @@ def test(dataset, checkpoint_file, result_path, config=None):
for frame in range(0, dataset.get_test_size()):
img, curr_img = dataset.next_batch(batch_size, 'test')
curr_frame = curr_img[0].split('/')[-1].split('.')[0] + '.png'
image = preprocess_img(img[0])
image = preprocess_img_tf(img[0]).eval()
res = sess.run(probabilities, feed_dict={input_image: image})
res_np = res.astype(np.float32)[0, :, :, 0] > 162.0/255.0
scipy.misc.imsave(os.path.join(result_path, curr_frame), res_np.astype(np.float32))
Expand Down