[Feature] Support OC-SORT for MOT #545

Merged: 18 commits, Aug 16, 2022
33 changes: 33 additions & 0 deletions configs/mot/ocsort/README.md
@@ -0,0 +1,33 @@
# Observation-Centric SORT: Rethinking SORT for Robust Multi-Object Tracking

## Abstract

<!-- [ABSTRACT] -->

Multi-Object Tracking (MOT) has rapidly progressed with the development of object detection and re-identification. However, motion modeling, which facilitates object association by forecasting short-term trajectories from past observations, has been relatively under-explored in recent years. Current motion models in MOT typically assume that object motion is linear within a small time window and require continuous observations, so these methods are sensitive to occlusions and non-linear motion and need high frame-rate videos. In this work, we show that a simple motion model can obtain state-of-the-art tracking performance without other cues like appearance. We emphasize the role of “observation” when recovering lost tracks and reducing the error accumulated by linear motion models during the lost period. We thus name the proposed method Observation-Centric SORT, OC-SORT for short. It remains simple, online, and real-time, but improves robustness to occlusion and non-linear motion. It achieves 63.2 and 62.1 HOTA on MOT17 and MOT20, respectively, surpassing all published methods. It also sets a new state of the art on KITTI Pedestrian Tracking and on DanceTrack, where the object motion is highly non-linear.

<!-- [IMAGE] -->
<div align="center">
<img src="https://user-images.githubusercontent.com/17743251/168193097-b3ad1a94-b18c-4b14-b7b1-5f8c6ed842f0.png"/>
</div>
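
As a rough intuition before the configs below: OC-SORT keeps the SORT pipeline but, when associating detections to tracks, adds a velocity-direction consistency term computed from past *observations* rather than from the Kalman state. The following is a minimal NumPy sketch of that observation-centric momentum (OCM) cost under our reading of the paper; the function and variable names are illustrative, not the mmtracking implementation.

```python
import numpy as np

def _center(box):
    """Center (cx, cy) of an x1, y1, x2, y2 box."""
    return np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2])

def _unit_direction(box_from, box_to, eps=1e-6):
    """Unit vector pointing from one box center to another."""
    v = _center(box_to) - _center(box_from)
    return v / (np.linalg.norm(v) + eps)

def ocm_cost(obs_past, obs_last, detection):
    """Angle (radians) between the track's observed velocity direction
    (obs_past -> obs_last, two observations a few frames apart) and the
    direction toward a candidate detection (obs_last -> detection).
    A small angle means the detection continues the observed motion."""
    v_track = _unit_direction(obs_past, obs_last)
    v_cand = _unit_direction(obs_last, detection)
    cos = np.clip(np.dot(v_track, v_cand), -1.0, 1.0)
    return np.arccos(cos)

# Scaled by a weight (`vel_consist_weight` in the configs below), this cost
# is added to the IoU-based cost before Hungarian matching.
```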

## Citation

<!-- [ALGORITHM] -->

```latex
@article{cao2022observation,
title={Observation-Centric SORT: Rethinking SORT for Robust Multi-Object Tracking},
author={Cao, Jinkun and Weng, Xinshuo and Khirodkar, Rawal and Pang, Jiangmiao and Kitani, Kris},
journal={arXiv preprint arXiv:2203.14360},
year={2022}
}
```

## Results and models on MOT17

The performance on `MOT17-half-val` is comparable to that of [the official OC-SORT implementation](https://github.com/noahcao/OC_SORT). We use the same YOLOX detector weights as [ByteTrack](https://github.com/open-mmlab/mmtracking/tree/master/configs/mot/bytetrack).

| Method | Detector | Train Set | Test Set | Public | Inf time (fps) | HOTA | MOTA | IDF1 | FP | FN | IDSw. | Config | Download |
| :-----: | :------: | :---------------------: | :------: | :----: | :------------: | :---: | :---: | :---: | :---: | :---: | :---: | :-------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| OC-SORT | YOLOX-X | CrowdHuman + half-train | half-val | N | - | 67.7 | 76.4 | 77.5 | 18516 | 19014 | 909 | [config](ocsort_yolox_x_crowdhuman_mot17-private-half.py) | [model](https://download.openmmlab.com/mmtracking/mot/bytetrack/bytetrack_yolox_x/bytetrack_yolox_x_crowdhuman_mot17-private-half_20211218_205500-1985c9f0.pth) &#124; log |
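
For a quick smoke test of this config, here is a hedged usage sketch with the mmtracking Python API (`init_model` / `inference_mot`); the input video path is illustrative, and the checkpoint is the ByteTrack-shared detector weight from the table above:

```python
import mmcv
from mmtrack.apis import inference_mot, init_model

config = 'configs/mot/ocsort/ocsort_yolox_x_crowdhuman_mot17-private-half.py'
# detector weights shared with ByteTrack (same URL as the table above)
checkpoint = (
    'https://download.openmmlab.com/mmtracking/mot/bytetrack/bytetrack_yolox_x/'
    'bytetrack_yolox_x_crowdhuman_mot17-private-half_20211218_205500-1985c9f0.pth')

model = init_model(config, checkpoint, device='cuda:0')
video = mmcv.VideoReader('demo.mp4')  # any video; path is illustrative
for frame_id, img in enumerate(video):
    result = inference_mot(model, img, frame_id=frame_id)
    # result['track_bboxes']: per-class arrays of [track_id, x1, y1, x2, y2, score]
```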
39 changes: 39 additions & 0 deletions configs/mot/ocsort/metafile.yml
@@ -0,0 +1,39 @@
Collections:
- Name: OCSORT
Metadata:
Training Techniques:
- SGD with Momentum
Training Resources: 8x V100 GPUs
Architecture:
- YOLOX
Paper:
URL: https://arxiv.org/abs/2203.14360
Title: 'Observation-Centric SORT: Rethinking SORT for Robust Multi-Object Tracking'
README: configs/mot/ocsort/README.md

Models:
- Name: ocsort_yolox_x_crowdhuman_mot17-private-half
In Collection: OCSORT
Config: configs/mot/ocsort/ocsort_yolox_x_crowdhuman_mot17-private-half.py
Metadata:
Training Data: CrowdHuman + MOT17-half-train
Results:
- Task: Multiple Object Tracking
Dataset: MOT17-half-val
Metrics:
MOTA: 78.6
IDF1: 79.2
Weights: https://download.openmmlab.com/mmtracking/mot/bytetrack/bytetrack_yolox_x/bytetrack_yolox_x_crowdhuman_mot17-private-half_20211218_205500-1985c9f0.pth

- Name: ocsort_yolox_x_crowdhuman_mot17-private
In Collection: OCSORT
Config: configs/mot/ocsort/ocsort_yolox_x_crowdhuman_mot17-private.py
Metadata:
Training Data: CrowdHuman + MOT17-half-train
Results:
- Task: Multiple Object Tracking
Dataset: MOT17-test
Metrics:
MOTA: 78.1
IDF1: 74.8
Weights: https://download.openmmlab.com/mmtracking/mot/bytetrack/bytetrack_yolox_x/bytetrack_yolox_x_crowdhuman_mot17-private-half_20211218_205500-1985c9f0.pth
167 changes: 167 additions & 0 deletions configs/mot/ocsort/ocsort_yolox_x_crowdhuman_mot17-private-half.py
@@ -0,0 +1,167 @@
_base_ = [
'../../_base_/models/yolox_x_8x8.py',
'../../_base_/datasets/mot_challenge.py', '../../_base_/default_runtime.py'
]

img_scale = (800, 1440)
samples_per_gpu = 4

model = dict(
type='OCSORT',
detector=dict(
input_size=img_scale,
random_size_range=(18, 32),
bbox_head=dict(num_classes=1),
test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.7)),
init_cfg=dict(
type='Pretrained',
checkpoint= # noqa: E251
'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth' # noqa: E501
)),
motion=dict(type='KalmanFilter'),
    tracker=dict(
        type='OCSORTTracker',
        obj_score_thr=0.3,  # keep detections above this score for association
        init_track_thr=0.7,  # detection score needed to start a new track
        weight_iou_with_det_scores=True,  # scale IoU by detection score when matching
        match_iou_thr=0.3,  # minimum IoU to accept a match
        num_tentatives=3,  # frames before a new track is confirmed
        vel_consist_weight=0.2,  # weight of the velocity-direction consistency cost
        vel_delta_t=3,  # frame gap used to estimate observation velocity
        num_frames_retain=30))  # drop a lost track after this many missed frames

train_pipeline = [
dict(
type='Mosaic',
img_scale=img_scale,
pad_val=114.0,
bbox_clip_border=False),
dict(
type='RandomAffine',
scaling_ratio_range=(0.1, 2),
border=(-img_scale[0] // 2, -img_scale[1] // 2),
bbox_clip_border=False),
dict(
type='MixUp',
img_scale=img_scale,
ratio_range=(0.8, 1.6),
pad_val=114.0,
bbox_clip_border=False),
dict(type='YOLOXHSVRandomAug'),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='Resize',
img_scale=img_scale,
keep_ratio=True,
bbox_clip_border=False),
dict(type='Pad', size_divisor=32, pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
to_rgb=False),
dict(
type='Pad',
size_divisor=32,
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='ImageToTensor', keys=['img']),
dict(type='VideoCollect', keys=['img'])
])
]
data = dict(
samples_per_gpu=samples_per_gpu,
workers_per_gpu=4,
persistent_workers=True,
train=dict(
_delete_=True,
type='MultiImageMixDataset',
dataset=dict(
type='CocoDataset',
ann_file=[
'data/MOT17/annotations/half-train_cocoformat.json',
'data/crowdhuman/annotations/crowdhuman_train.json',
'data/crowdhuman/annotations/crowdhuman_val.json'
],
img_prefix=[
'data/MOT17/train', 'data/crowdhuman/train',
'data/crowdhuman/val'
],
classes=('pedestrian', ),
pipeline=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True)
],
filter_empty_gt=False),
pipeline=train_pipeline),
val=dict(
pipeline=test_pipeline,
interpolate_tracks_cfg=dict(min_num_frames=5, max_num_frames=20)),
test=dict(
pipeline=test_pipeline,
interpolate_tracks_cfg=dict(min_num_frames=5, max_num_frames=20)))

# optimizer
# defaults assume 8 GPUs; lr scales linearly with samples_per_gpu
optimizer = dict(
type='SGD',
lr=0.001 / 8 * samples_per_gpu,
momentum=0.9,
weight_decay=5e-4,
nesterov=True,
paramwise_cfg=dict(norm_decay_mult=0.0, bias_decay_mult=0.0))
optimizer_config = dict(grad_clip=None)

# schedule hyperparameters
total_epochs = 80
num_last_epochs = 10
resume_from = None
interval = 5

# learning policy
lr_config = dict(
policy='YOLOX',
warmup='exp',
by_epoch=False,
warmup_by_epoch=True,
warmup_ratio=1,
warmup_iters=1,
num_last_epochs=num_last_epochs,
min_lr_ratio=0.05)

custom_hooks = [
dict(
type='YOLOXModeSwitchHook',
num_last_epochs=num_last_epochs,
priority=48),
dict(
type='SyncNormHook',
num_last_epochs=num_last_epochs,
interval=interval,
priority=48),
dict(
type='ExpMomentumEMAHook',
resume_from=resume_from,
momentum=0.0001,
priority=49)
]

checkpoint_config = dict(interval=1)
evaluation = dict(metric=['bbox', 'track'], interval=1)
search_metrics = ['MOTA', 'IDF1', 'FN', 'FP', 'IDs', 'MT', 'ML']

# you need to set mode='dynamic' if you are using pytorch<=1.5.0
fp16 = dict(loss_scale=dict(init_scale=512.))
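
The `interpolate_tracks_cfg` entries in the data settings above enable offline gap filling of tracklets at evaluation time. Below is a minimal sketch of the assumed behavior (linear interpolation of boxes across short gaps in sufficiently long tracks); it is an illustration, not the exact mmtracking code.

```python
import numpy as np

def interpolate_track(frames, bboxes, min_num_frames=5, max_num_frames=20):
    """Fill short gaps in one track by linear interpolation.

    frames: sorted 1-D int array of frame ids with observations.
    bboxes: (N, 4) float array of x1, y1, x2, y2 for those frames.
    """
    if len(frames) < min_num_frames:
        return frames, bboxes  # too short: assumed left untouched
    out_f, out_b = [frames[0]], [bboxes[0]]
    for i in range(1, len(frames)):
        gap = frames[i] - frames[i - 1]
        if 1 < gap <= max_num_frames:
            for k in range(1, gap):
                w = k / gap  # interpolation weight between the two observations
                out_f.append(frames[i - 1] + k)
                out_b.append((1 - w) * bboxes[i - 1] + w * bboxes[i])
        out_f.append(frames[i])
        out_b.append(bboxes[i])
    return np.asarray(out_f), np.stack(out_b)
```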
6 changes: 6 additions & 0 deletions configs/mot/ocsort/ocsort_yolox_x_crowdhuman_mot17-private.py
@@ -0,0 +1,6 @@
_base_ = ['./ocsort_yolox_x_crowdhuman_mot17-private-half.py']

data = dict(
test=dict(
ann_file='data/MOT17/annotations/test_cocoformat.json',
img_prefix='data/MOT17/test'))
76 changes: 76 additions & 0 deletions configs/mot/ocsort/ocsort_yolox_x_crowdhuman_mot20-private-half.py
@@ -0,0 +1,76 @@
_base_ = ['./ocsort_yolox_x_crowdhuman_mot17-private-half.py']

img_scale = (896, 1600)

model = dict(
detector=dict(input_size=img_scale, random_size_range=(20, 36)),
tracker=dict(
weight_iou_with_det_scores=False,
match_iou_thr=0.3,
))

train_pipeline = [
dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
dict(
type='RandomAffine',
scaling_ratio_range=(0.1, 2),
border=(-img_scale[0] // 2, -img_scale[1] // 2)),
dict(
type='MixUp',
img_scale=img_scale,
ratio_range=(0.8, 1.6),
pad_val=114.0),
dict(type='YOLOXHSVRandomAug'),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Resize', img_scale=img_scale, keep_ratio=True),
dict(type='Pad', size_divisor=32, pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
to_rgb=False),
dict(
type='Pad',
size_divisor=32,
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='ImageToTensor', keys=['img']),
dict(type='VideoCollect', keys=['img'])
])
]
data = dict(
train=dict(
dataset=dict(
ann_file=[
'data/MOT20/annotations/half-train_cocoformat.json',
'data/crowdhuman/annotations/crowdhuman_train.json',
'data/crowdhuman/annotations/crowdhuman_val.json'
],
img_prefix=[
'data/MOT20/train', 'data/crowdhuman/train',
'data/crowdhuman/val'
]),
pipeline=train_pipeline),
val=dict(
ann_file='data/MOT20/annotations/half-val_cocoformat.json',
img_prefix='data/MOT20/train',
pipeline=test_pipeline),
test=dict(
ann_file='data/MOT20/annotations/half-val_cocoformat.json',
img_prefix='data/MOT20/train',
pipeline=test_pipeline))

checkpoint_config = dict(interval=1)
evaluation = dict(metric=['bbox', 'track'], interval=100)
4 changes: 3 additions & 1 deletion mmtrack/models/mot/__init__.py
@@ -2,9 +2,11 @@
from .base import BaseMultiObjectTracker
from .byte_track import ByteTrack
from .deep_sort import DeepSORT
from .ocsort import OCSORT
from .qdtrack import QDTrack
from .tracktor import Tracktor

__all__ = [
'BaseMultiObjectTracker', 'Tracktor', 'DeepSORT', 'ByteTrack', 'QDTrack'
'BaseMultiObjectTracker', 'Tracktor', 'DeepSORT', 'ByteTrack', 'QDTrack',
'OCSORT'
]