forked from MarvinTeichmann/KittiBox
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
82 lines (62 loc) · 2.34 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Trains, evaluates and saves the TensorDetect model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import logging
import os
import sys
# configure logging: INFO and above, with timestamp and level, to stdout.
# The TV_IS_DEV environment variable does not change this configuration;
# add a branch here if development runs ever need different settings.
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                    level=logging.INFO,
                    stream=sys.stdout)

# Logging is configured before TensorFlow is imported below; context:
# https://github.com/tensorflow/tensorflow/issues/2034#issuecomment-220820070
import numpy as np
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS

# Make the bundled tensorvision package importable. 'incl' is a relative
# entry, so it is resolved against the current working directory
# (presumably the repository root, where the submodule is checked out).
sys.path.insert(1, 'incl')

try:
    import tensorvision.train as train
    import tensorvision.utils as utils
except ImportError:
    # Give the user an actionable hint, then re-raise the same exception so
    # the failure (and its traceback) is unchanged for anyone catching it.
    logging.error("Could not import the submodules.")
    logging.error("Please execute: "
                  "'git submodule update --init --recursive'")
    raise
# Command-line flags. Only --hypes is read in this file (by main());
# the others are presumably consumed by tensorvision -- TODO confirm.
flags.DEFINE_string('name', None,
                    'Append a name Tag to run.')
# NOTE(review): this help text duplicates --name's; confirm what --project
# controls and give it its own description.
flags.DEFINE_string('project', None,
                    'Append a name Tag to run.')
flags.DEFINE_string('hypes', 'hypes/kittiBox.json',
                    'File storing model parameters.')
# Defaults to True, so runs are saved unless --nosave is passed.
flags.DEFINE_boolean(
    'save', True, ('Whether to save the run. In case --nosave is given, '
                   'output will be saved to the folder TV_DIR_RUNS/debug, '
                   'hence it will get overwritten by further runs.'))
def main(_):
    """Load the --hypes JSON file and run KittiBox training with it.

    The single positional argument (supplied by ``tf.app.run``) is ignored.

    In order, this calls ``utils.set_gpus_to_use``, loads the hyperparameter
    dict from ``FLAGS.hypes``, calls ``utils.load_plugins``, scopes
    ``TV_DIR_RUNS`` (if set) to a ``KittiBox`` subfolder, sets up directories
    and sys.path via tensorvision utils, and finally calls
    ``train.initialize_training_folder``, ``train.maybe_download_and_extract``
    and ``train.do_training`` with the hypes.
    """
    utils.set_gpus_to_use()

    # tensorvision.train and tensorvision.utils are imported at module import
    # time, so a missing submodule already fails before main() can run; no
    # separate import check is needed here.

    hypes_path = FLAGS.hypes
    with open(hypes_path, 'r') as f:
        logging.info("f: %s", f)
        hypes = json.load(f)
    utils.load_plugins()

    # Nest this project's runs under TV_DIR_RUNS/KittiBox. This rewrites the
    # environment variable for the rest of the process.
    if 'TV_DIR_RUNS' in os.environ:
        os.environ['TV_DIR_RUNS'] = os.path.join(os.environ['TV_DIR_RUNS'],
                                                 'KittiBox')
    utils.set_dirs(hypes, hypes_path)

    utils._add_paths_to_sys(hypes)

    logging.info("Initialize training folder")
    train.initialize_training_folder(hypes)
    train.maybe_download_and_extract(hypes)
    logging.info("Start training")
    train.do_training(hypes)
if __name__ == '__main__':
    # Hand control to tf.app.run, which parses the command-line flags and
    # then invokes main() explicitly.
    tf.app.run(main=main)