Skip to content

Commit

Permalink
Update train and test scripts for RPN optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
fschaeffler93 authored and fschaeffler93 committed Jul 9, 2019
1 parent 413384a commit c175dfb
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 0 deletions.
6 changes: 6 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
parser = argparse.ArgumentParser(description='testing')
parser.add_argument('-n', '--tag', type=str, nargs='?', default='pre_trained_car',
help='set log tag')
parser.add_argument('-d', '--decrease', type=bool, nargs='?', default=False,
help='set the flag to True if decrease model')
parser.add_argument('-m', '--minimize', type=bool, nargs='?', default=False,
help='set the flag to True if minimize model')
parser.add_argument('-b', '--single-batch-size', type=int, nargs='?', default=1,
help='set batch size for each gpu')
parser.add_argument('-o', '--output-path', type=str, nargs='?', default='predictions',
Expand Down Expand Up @@ -56,6 +60,8 @@ def main(_):
with tf.Session(config=config) as sess:
model = RPN3D(
cls=cfg.DETECT_OBJ,
decrease=args.decrease,
minimize=args.minimize,
single_batch_size=args.single_batch_size,
avail_gpus=cfg.GPU_AVAILABLE.split(',')
)
Expand Down
6 changes: 6 additions & 0 deletions test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
parser = argparse.ArgumentParser(description='testing')
parser.add_argument('-n', '--tag', type=str, nargs='?', default='pre_trained_car',
help='set log tag')
parser.add_argument('-d', '--decrease', type=bool, nargs='?', default=False,
help='set the flag to True if decrease model')
parser.add_argument('-m', '--minimize', type=bool, nargs='?', default=False,
help='set the flag to True if minimize model')
parser.add_argument('-t', '--data-tag', type=str, nargs='?', default='000000',
help='set data tag')
parser.add_argument('-o', '--output-path', type=str, nargs='?',default='./predictions',
Expand Down Expand Up @@ -56,6 +60,8 @@ def main(_):
with tf.Session(config=config) as sess:
model = RPN3D(
cls=cfg.DETECT_OBJ,
decrease=args.decrease,
minimize=args.minimize,
single_batch_size=1,
avail_gpus=cfg.GPU_AVAILABLE.split(',')
)
Expand Down
6 changes: 6 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
help='max epoch')
parser.add_argument('-n', '--tag', type=str, nargs='?', default='default',
help='set log tag')
parser.add_argument('-d', '--decrease', type=bool, nargs='?', default=False,
help='set the flag to True if decrease model')
parser.add_argument('-m', '--minimize', type=bool, nargs='?', default=False,
help='set the flag to True if minimize model')
parser.add_argument('-b', '--single-batch-size', type=int, nargs='?', default=1,
help='set batch size for each gpu')
parser.add_argument('-l', '--lr', type=float, nargs='?', default=0.001,
Expand Down Expand Up @@ -68,6 +72,8 @@ def main(_):
with tf.Session(config=config) as sess:
model = RPN3D(
cls=cfg.DETECT_OBJ,
decrease=args.decrease,
minimize=args.minimize,
single_batch_size=args.single_batch_size,
learning_rate=args.lr,
max_gradient_norm=5.0,
Expand Down

0 comments on commit c175dfb

Please sign in to comment.