Skip to content

Commit

Permalink
add temporal model
Browse files (browse the repository at this point in the history)
  • Loading branch information
HaydenFaulkner committed Dec 11, 2019
1 parent 99c48b4 commit d18c8c5
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
6 changes: 5 additions & 1 deletion models/definitions/yolo/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from gluoncv.model_zoo import get_model

from .yolo3 import YOLOV3_noback, YOLOV3, YOLOV3T, YOLOV3TS, YOLOV3TB
from .yolo3_temporal import YOLOV3Temporal
from ..darknet.three_darknet import get_darknet
from ..darknet.h_darknet import get_hdarknet
from ..darknet.ts_darknet import get_darknet_flownet, get_darknet_r21d
Expand All @@ -11,7 +12,7 @@
def yolo3_darknet53(classes, pretrained_base=True, norm_layer=BatchNorm, norm_kwargs=None, freeze_base=False,
k=None, k_join_type=None, k_join_pos=None, block_conv_type='2', rnn_pos=None,
corr_pos=None, corr_d=None, motion_stream=None, add_type=None, agnostic=False, new_model=False,
hierarchical=[1,1,1,1,1], h_join_type=None, **kwargs):
hierarchical=[1,1,1,1,1], h_join_type=None, temporal=False, **kwargs):
"""YOLO3 multi-scale with darknet53 base network on any dataset. Modified from:
https://github.com/dmlc/gluon-cv/blob/0dbd05c5eb8537c25b64f0e87c09be979303abf2/gluoncv/model_zoo/yolo/yolo3.py
Expand Down Expand Up @@ -93,6 +94,9 @@ def yolo3_darknet53(classes, pretrained_base=True, norm_layer=BatchNorm, norm_kw
k_join_pos=k_join_pos, block_conv_type=block_conv_type, rnn_shapes=rnn_shapes,
rnn_pos=rnn_pos,
corr_pos=corr_pos, corr_d=corr_d, agnostic=agnostic, **kwargs)
elif temporal:
net = YOLOV3Temporal(stages, [512, 256, 128], anchors, strides,
classes=classes, t=k, conv=int(block_conv_type), corr_d=corr_d, **kwargs)
else:
# OLD CODE
net = YOLOV3T(stages, [512, 256, 128], anchors, strides, classes=classes, k=k, k_join_type=k_join_type,
Expand Down
20 changes: 13 additions & 7 deletions train_yolov3.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from datasets.imgnetdet import ImageNetDetection
from datasets.imgnetvid import ImageNetVidDetection

from metrics.pascalvoc import VOCMApMetric
from metrics.pascalvoc import VOCMApMetric, VOCMApMetricTemporal
from metrics.mscoco import COCODetectionMetric

from models.definitions.yolo.wrappers import yolo3_darknet53, yolo3_no_backbone, yolo3_3ddarknet
Expand Down Expand Up @@ -139,7 +139,7 @@
"position of RNN, currently only supports 'late' or 'out")
flags.DEFINE_string('corr_pos', None,
"position of correlation features calculation, currently only supports 'early' or 'late")
flags.DEFINE_integer('corr_d', 4,
flags.DEFINE_integer('corr_d', 0,
'The d value for the correlation filter.')
flags.DEFINE_string('motion_stream', None,
'Add a motion stream? can be flownet or r21d.')
Expand Down Expand Up @@ -177,7 +177,10 @@ def get_dataset(dataset_name, save_prefix=''):
val_dataset = ImageNetVidDetection(splits=[(2017, 'val')], allow_empty=FLAGS.allow_empty,
every=FLAGS.every, window=FLAGS.window, features_dir=FLAGS.features_dir,
mult_out=FLAGS.mult_out)
val_metric = VOCMApMetric(iou_thresh=0.5, class_names=val_dataset.classes)
if FLAGS.mult_out:
val_metric = VOCMApMetricTemporal(t=int(FLAGS.window[0]), iou_thresh=0.5, class_names=val_dataset.classes)
else:
val_metric = VOCMApMetric(iou_thresh=0.5, class_names=val_dataset.classes)

else:
raise NotImplementedError('Dataset: {} not implemented.'.format(dataset_name))
Expand Down Expand Up @@ -207,7 +210,6 @@ def get_dataloader(net, train_dataset, val_dataset, batch_size):

return train_loader, val_loader


# stack the first six fields (image and all pre-generated targets); pad the last field along axis 0 with -1
batchify_fn = Tuple(*([Stack() for _ in range(6)] + [Pad(axis=0, pad_val=-1) for _ in range(1)]))

Expand Down Expand Up @@ -306,7 +308,8 @@ def get_net(trained_on_dataset, ctx, definition='ours'):
block_conv_type=FLAGS.block_conv_type, rnn_pos=FLAGS.rnn_pos,
corr_pos=FLAGS.corr_pos, corr_d=FLAGS.corr_d, motion_stream=FLAGS.motion_stream,
add_type=FLAGS.stream_gating, new_model=FLAGS.new_model,
hierarchical=FLAGS.hier, h_join_type=FLAGS.h_join_type)
hierarchical=FLAGS.hier, h_join_type=FLAGS.h_join_type,
temporal=FLAGS.mult_out)
async_net = yolo3_darknet53(trained_on_dataset.classes,
pretrained_base=False,
freeze_base=bool(FLAGS.freeze_base),
Expand All @@ -315,7 +318,8 @@ def get_net(trained_on_dataset, ctx, definition='ours'):
corr_pos=FLAGS.corr_pos, corr_d=FLAGS.corr_d,
motion_stream=FLAGS.motion_stream, add_type=FLAGS.stream_gating,
new_model=FLAGS.new_model,
hierarchical=FLAGS.hier, h_join_type=FLAGS.h_join_type) # used by cpu worker
hierarchical=FLAGS.hier, h_join_type=FLAGS.h_join_type,
temporal=FLAGS.mult_out) # used by cpu worker
else:
net = yolo3_3ddarknet(trained_on_dataset.classes,
pretrained_base=FLAGS.pretrained_cnn,
Expand All @@ -336,7 +340,8 @@ def get_net(trained_on_dataset, ctx, definition='ours'):
block_conv_type=FLAGS.block_conv_type, rnn_pos=FLAGS.rnn_pos,
corr_pos=FLAGS.corr_pos, corr_d=FLAGS.corr_d, motion_stream=FLAGS.motion_stream,
add_type=FLAGS.stream_gating, new_model=FLAGS.new_model,
hierarchical=FLAGS.hier, h_join_type=FLAGS.h_join_type)
hierarchical=FLAGS.hier, h_join_type=FLAGS.h_join_type,
temporal=FLAGS.mult_out)
async_net = net
else:
net = yolo3_3ddarknet(trained_on_dataset.classes,
Expand Down Expand Up @@ -574,6 +579,7 @@ def train(net, train_data, train_dataset, val_data, eval_metric, ctx, save_prefi
center_metrics.update(0, center_losses)
scale_metrics.update(0, scale_losses)
cls_metrics.update(0, cls_losses)

if FLAGS.log_interval and not (i + 1) % FLAGS.log_interval:
name1, loss1 = obj_metrics.get()
name2, loss2 = center_metrics.get()
Expand Down

0 comments on commit d18c8c5

Please sign in to comment.