-
Notifications
You must be signed in to change notification settings - Fork 0
/
osvos_train_test.py
50 lines (45 loc) · 2.07 KB
/
osvos_train_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Helper script: optionally fine-tune OSVOS on one DAVIS sequence and then
# segment every frame of that sequence (see train_and_test_osvos below).
from __future__ import print_function
import os
import sys
# NOTE(review): Image, np, plt and slim are not referenced by
# train_and_test_osvos; they may be kept for parity with other OSVOS scripts.
from PIL import Image
import numpy as np
import tensorflow as tf
# NOTE(review): tf.contrib was removed in TensorFlow 2.x, so this line
# presumably requires a TensorFlow 1.x installation.
slim = tf.contrib.slim
import matplotlib.pyplot as plt
# Import OSVOS files
# Put this script's own (absolute) directory on sys.path so the sibling
# project modules `osvos` and `dataset` can be imported below; this must run
# before those two imports.
root_folder = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.abspath(root_folder))
import osvos
from dataset import Dataset
# Import-time side effect: change the process working directory to the script
# folder so the relative 'models/...' and 'DAVIS/...' paths used in
# train_and_test_osvos resolve against it.
os.chdir(root_folder)
def train_and_test_osvos(seq_name, gpu_id, result_path, train_model, max_train_iters, train_img_name,
                         annot_img_name, learning_rate=1e-8, parent_path=None):
    """Optionally fine-tune OSVOS on one DAVIS sequence, then segment all of its frames.

    All paths are relative to the current working directory.

    Args:
        seq_name: DAVIS sequence name. Used as the sub-folder under
            ``DAVIS/JPEGImages/480p`` and ``DAVIS/Annotations/480p``, as the
            ``models/`` sub-folder for logs/checkpoints, and as the checkpoint name.
        gpu_id: Index of the GPU both graphs are placed on (``'/gpu:<gpu_id>'``).
        result_path: Output location passed to ``osvos.test``.
        train_model: If true, fine-tune from ``parent_path`` before testing.
            If false, only test; ``train_img_name`` and ``annot_img_name`` are
            then ignored and an existing fine-tuned checkpoint is loaded.
        max_train_iters: Number of fine-tuning iterations. Also selects the
            checkpoint suffix loaded for testing
            (``models/<seq_name>/<seq_name>.ckpt-<max_train_iters>``), so with
            ``train_model`` false it must match an existing checkpoint.
        train_img_name: File name (inside the sequence's JPEG folder) of the
            frame used for fine-tuning.
        annot_img_name: File name (inside the sequence's annotation folder) of
            that frame's annotation.
        learning_rate: Fine-tuning learning rate; defaults to the previously
            hard-coded 1e-8.
        parent_path: Parent checkpoint to fine-tune from; defaults to
            ``models/OSVOS_parent/OSVOS_parent.ckpt-50000``.
    """
    if parent_path is None:
        parent_path = os.path.join('models', 'OSVOS_parent', 'OSVOS_parent.ckpt-50000')
    # Training writes its logs/checkpoints here and testing reads the
    # checkpoint back from the same directory, so both use this one variable.
    logs_path = os.path.join('models', seq_name)
    max_training_iters = max_train_iters
    gpu_device = '/gpu:' + str(gpu_id)

    # Every file in the sequence folder is a test frame; sorting gives a
    # deterministic, file-name-ordered frame sequence.
    seq_img_dir = os.path.join('DAVIS', 'JPEGImages', '480p', seq_name)
    test_imgs = [os.path.join(seq_img_dir, frame) for frame in sorted(os.listdir(seq_img_dir))]

    if train_model:
        # One training entry: "<image path> <annotation path>" joined by a
        # single space (presumably split by Dataset -- confirm in dataset.py).
        train_imgs = [os.path.join(seq_img_dir, train_img_name) + ' ' +
                      os.path.join('DAVIS', 'Annotations', '480p', seq_name, annot_img_name)]
        dataset = Dataset(train_imgs, test_imgs, './', data_aug=True)

        # save_step == max_training_iters so a checkpoint suffixed with
        # max_training_iters exists for the test phase (assumes train_finetune
        # saves on multiples of save_step -- confirm in osvos.py).
        save_step = max_training_iters
        side_supervision = 3
        display_step = 10
        with tf.Graph().as_default():
            with tf.device(gpu_device):
                global_step = tf.Variable(0, name='global_step', trainable=False)
                osvos.train_finetune(dataset, parent_path, side_supervision, learning_rate, logs_path,
                                     max_training_iters, save_step, display_step, global_step,
                                     iter_mean_grad=1, ckpt_name=seq_name)
    else:
        dataset = Dataset(None, test_imgs, './')

    # Test the network in a fresh graph, separate from the training graph.
    with tf.Graph().as_default():
        with tf.device(gpu_device):
            checkpoint_path = os.path.join(logs_path, seq_name + '.ckpt-' + str(max_training_iters))
            osvos.test(dataset, checkpoint_path, result_path)