Skip to content

Commit

Permalink
dont load net if predictions exist
Browse files Browse the repository at this point in the history
  • Loading branch information
HaydenFaulkner committed Feb 27, 2020
1 parent d2c8875 commit b21923e
Showing 1 changed file with 27 additions and 25 deletions.
52 changes: 27 additions & 25 deletions detect_yolo3.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,31 +719,6 @@ def main(_argv):
ctx = [mx.gpu(int(i)) for i in gpus]
ctx = ctx if ctx else [mx.cpu()]

# dataloader
loader = get_dataloader(dataset, batch_size)

# setup network
# net_name = '_'.join(('yolo3', FLAGS.network, 'custom'))
# net = get_model(net_name, root='models', pretrained_base=True, classes=trained_on_dataset.classes)
if FLAGS.network == 'darknet53':
if FLAGS.conv_types[0] is 2:
net = yolo3_darknet53(trained_on_dataset.classes,
k=FLAGS.window[0], k_join_type=FLAGS.k_join_type, k_join_pos=FLAGS.k_join_pos,
block_conv_type=FLAGS.block_conv_type, rnn_pos=FLAGS.rnn_pos,
corr_pos=FLAGS.corr_pos, corr_d=FLAGS.corr_d, motion_stream=FLAGS.motion_stream,
agnostic=FLAGS.model_agnostic, add_type=FLAGS.stream_gating, new_model=FLAGS.new_model,
hierarchical=FLAGS.hier, h_join_type=FLAGS.h_join_type, temporal=FLAGS.temp, t_out=FLAGS.mult_out)
else:
net = yolo3_3ddarknet(trained_on_dataset.classes, conv_types=FLAGS.conv_types)
else:
raise NotImplementedError('Backbone CNN model {} not implemented.'.format(FLAGS.network))
net.initialize()
if FLAGS.window[0] > 1:
net.summary(mx.nd.random_normal(shape=(1, FLAGS.window[0], 3, FLAGS.data_shape, FLAGS.data_shape)))
else:
net.summary(mx.nd.random_normal(shape=(1, 3, FLAGS.data_shape, FLAGS.data_shape)))
net.load_parameters(model_path)

max_do = FLAGS.max_do
if max_do < 0:
max_do = len(dataset)
Expand All @@ -767,6 +742,33 @@ def main(_argv):
agnostic=FLAGS.model_agnostic)

if predictions is None: # id not exist detect and make
# dataloader
loader = get_dataloader(dataset, batch_size)

# setup network
# net_name = '_'.join(('yolo3', FLAGS.network, 'custom'))
# net = get_model(net_name, root='models', pretrained_base=True, classes=trained_on_dataset.classes)
if FLAGS.network == 'darknet53':
if FLAGS.conv_types[0] is 2:
net = yolo3_darknet53(trained_on_dataset.classes,
k=FLAGS.window[0], k_join_type=FLAGS.k_join_type, k_join_pos=FLAGS.k_join_pos,
block_conv_type=FLAGS.block_conv_type, rnn_pos=FLAGS.rnn_pos,
corr_pos=FLAGS.corr_pos, corr_d=FLAGS.corr_d, motion_stream=FLAGS.motion_stream,
agnostic=FLAGS.model_agnostic, add_type=FLAGS.stream_gating,
new_model=FLAGS.new_model,
hierarchical=FLAGS.hier, h_join_type=FLAGS.h_join_type, temporal=FLAGS.temp,
t_out=FLAGS.mult_out)
else:
net = yolo3_3ddarknet(trained_on_dataset.classes, conv_types=FLAGS.conv_types)
else:
raise NotImplementedError('Backbone CNN model {} not implemented.'.format(FLAGS.network))
net.initialize()
if FLAGS.window[0] > 1:
net.summary(mx.nd.random_normal(shape=(1, FLAGS.window[0], 3, FLAGS.data_shape, FLAGS.data_shape)))
else:
net.summary(mx.nd.random_normal(shape=(1, 3, FLAGS.data_shape, FLAGS.data_shape)))
net.load_parameters(model_path)

predictions = detect(net, dataset, loader, ctx, max_do=max_do) # todo fix det thresh
save_predictions(save_dir, dataset, predictions, agnostic=FLAGS.model_agnostic)

Expand Down

0 comments on commit b21923e

Please sign in to comment.