Skip to content

Commit

Permalink
[Enhancement] Refactor TrackingNetDataset, SOTCocoDataset and SOTImag…
Browse files Browse the repository at this point in the history
…eNetVIDDataset based on BaseSOTDataset (open-mmlab#402)

* add trackingnet, sot_coco and sot_imagenetvid

* fix docs and naming issues

* change configs and fix test annotation bug

* fix docs

* fix formats

* fix unit test

* recover the original lasot json

* json format

* small changes

* small changes

* decrepate get_img_names_from_video and fix docs
  • Loading branch information
JingweiZhang12 authored Jan 25, 2022
1 parent 596e6cd commit bb0a643
Show file tree
Hide file tree
Showing 16 changed files with 429 additions and 141 deletions.
2 changes: 1 addition & 1 deletion configs/sot/siamese_rpn/siamese_rpn_r50_1x_lasot.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
]
test_pipeline = [
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='LoadAnnotations', with_bbox=True, with_label=False),
dict(
type='MultiScaleFlipAug',
scale_factor=1,
Expand Down
4 changes: 2 additions & 2 deletions configs/sot/siamese_rpn/siamese_rpn_r50_1x_trackingnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
data = dict(
test=dict(
type='TrackingNetDataset',
ann_file=data_root + 'trackingnet/annotations/trackingnet_test.json',
img_prefix=data_root + 'trackingnet'))
img_prefix=data_root + 'trackingnet',
split='test'))
5 changes: 4 additions & 1 deletion mmtrack/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from .parsers import CocoVID
from .pipelines import PIPELINES
from .reid_dataset import ReIDDataset
from .sot_coco_dataset import SOTCocoDataset
from .sot_imagenet_vid_dataset import SOTImageNetVIDDataset
from .sot_test_dataset import SOTTestDataset
from .sot_train_dataset import SOTTrainDataset
from .trackingnet_dataset import TrackingNetDataset
Expand All @@ -24,5 +26,6 @@
'CocoVideoDataset', 'ImagenetVIDDataset', 'MOTChallengeDataset',
'ReIDDataset', 'SOTTrainDataset', 'SOTTestDataset', 'LaSOTDataset',
'UAV123Dataset', 'TrackingNetDataset', 'OTB100Dataset',
'YouTubeVISDataset', 'GOT10kDataset', 'VOTDataset', 'BaseSOTDataset'
'YouTubeVISDataset', 'GOT10kDataset', 'VOTDataset', 'BaseSOTDataset',
'SOTCocoDataset', 'SOTImageNetVIDDataset'
]
36 changes: 12 additions & 24 deletions mmtrack/datasets/base_sot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def __init__(self,
{
'video_path': the video path
'ann_path': the annotation path
'start_frame_id': the starting frame number contained in
'start_frame_id': the starting frame ID number contained in
the image name
'end_frame_id': the ending frame number contained in the
'end_frame_id': the ending frame ID number contained in the
image name
'framename_template': the template of image name
},
Expand All @@ -78,33 +78,14 @@ def __getitem__(self, ind):
def load_data_infos(self, split='train'):
pass

def get_img_names_from_video(self, video_ind):
"""Get all frame paths in a video.
Args:
video_ind (int): video index
Returns:
list[str]: all image paths
"""
img_names = []
start_frame_id = self.data_infos[video_ind]['start_frame_id']
end_frame_id = self.data_infos[video_ind]['end_frame_id']
framename_template = self.data_infos[video_ind]['framename_template']
for frame_id in range(start_frame_id, end_frame_id + 1):
img_names.append(
osp.join(self.data_infos[video_ind]['video_path'],
framename_template % frame_id))
return img_names

def get_bboxes_from_video(self, video_ind):
"""Get bboxes annotation about the instance in a video.
Args:
video_ind (int): video index
Returns:
ndarray: in [N, 4] shape. The N is the bbox number and the bbox
ndarray: in [N, 4] shape. The N is the number of bbox and the bbox
is in (x, y, w, h) format.
"""
bbox_path = osp.join(self.img_prefix,
Expand All @@ -123,7 +104,7 @@ def get_bboxes_from_video(self, video_ind):
return bboxes

def get_len_per_video(self, video_ind):
"""Get the frame number in a video."""
"""Get the number of frames in a video."""
return self.data_infos[video_ind]['end_frame_id'] - self.data_infos[
video_ind]['start_frame_id'] + 1

Expand Down Expand Up @@ -168,7 +149,14 @@ def get_img_infos_from_video(self, video_ind):
Returns:
dict: {'filename': list[str], 'frame_ids':ndarray, 'video_id':int}
"""
img_names = self.get_img_names_from_video(video_ind)
img_names = []
start_frame_id = self.data_infos[video_ind]['start_frame_id']
end_frame_id = self.data_infos[video_ind]['end_frame_id']
framename_template = self.data_infos[video_ind]['framename_template']
for frame_id in range(start_frame_id, end_frame_id + 1):
img_names.append(
osp.join(self.data_infos[video_ind]['video_path'],
framename_template % frame_id))
frame_ids = np.arange(self.get_len_per_video(video_ind))
img_infos = dict(
filename=img_names, frame_ids=frame_ids, video_id=video_ind)
Expand Down
14 changes: 12 additions & 2 deletions mmtrack/datasets/got10k_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

@DATASETS.register_module()
class GOT10kDataset(BaseSOTDataset):
"""Dataset of single object tracking.
"""GOT10k Dataset of single object tracking.
The dataset can both support training and testing mode.
"""
Expand All @@ -28,7 +28,17 @@ def load_data_infos(self, split='train'):
split (str, optional): the split of dataset. Defaults to 'train'.
Returns:
list[dict]: the length of the list is the number of videos.
list[dict]: the length of the list is the number of videos. The
inner dict is in the following format:
{
'video_path': the video path
'ann_path': the annotation path
'start_frame_id': the starting frame number contained
in the image name
'end_frame_id': the ending frame number contained in
the image name
'framename_template': the template of image name
}
"""
print('Loading GOT10k dataset...')
start_time = time.time()
Expand Down
82 changes: 82 additions & 0 deletions mmtrack/datasets/sot_coco_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
import time

import numpy as np
from mmdet.datasets import DATASETS
from pycocotools.coco import COCO

from .base_sot_dataset import BaseSOTDataset


@DATASETS.register_module()
class SOTCocoDataset(BaseSOTDataset):
"""Coco dataset of single object tracking.
The dataset only support training mode.
"""

def __init__(self, ann_file, *args, **kwargs):
"""Initialization of SOT dataset class.
Args:
ann_file (str): The official coco annotation file. It will be
loaded and parsed in the `self.load_data_infos` function.
"""
self.coco = COCO(ann_file)
super().__init__(*args, **kwargs)

def load_data_infos(self, split='train'):
"""Load dataset information. Each instance is viewed as a video.
Args:
split (str, optional): The split of dataset. Defaults to 'train'.
Returns:
list[int]: The length of the list is the number of valid object
annotations. The elemment in the list is annotation ID in coco
API.
"""
print('Loading Coco dataset...')
start_time = time.time()
ann_list = list(self.coco.anns.keys())
videos_list = [
ann for ann in ann_list
if self.coco.anns[ann].get('iscrowd', 0) == 0
]
print(f'Coco dataset loaded! ({time.time()-start_time:.2f} s)')
return videos_list

def get_bboxes_from_video(self, video_ind):
"""Get bbox annotation about the instance in an image.
Args:
video_ind (int): video index. Each video_ind denotes an instance.
Returns:
ndarray: in [1, 4] shape. The bbox is in (x, y, w, h) format.
"""
ann_id = self.data_infos[video_ind]
anno = self.coco.anns[ann_id]
bboxes = np.array(anno['bbox']).reshape(-1, 4)
return bboxes

def get_img_infos_from_video(self, video_ind):
"""Get all frame paths in a video.
Args:
video_ind (int): video index. Each video_ind denotes an instance.
Returns:
list[str]: all image paths
"""
ann_id = self.data_infos[video_ind]
imgs = self.coco.loadImgs([self.coco.anns[ann_id]['image_id']])
img_names = [img['file_name'] for img in imgs]
frame_ids = np.arange(self.get_len_per_video(video_ind))
img_infos = dict(
filename=img_names, frame_ids=frame_ids, video_id=video_ind)
return img_infos

def get_len_per_video(self, video_ind):
"""Get the number of frames in a video."""
return 1
132 changes: 132 additions & 0 deletions mmtrack/datasets/sot_imagenet_vid_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import numpy as np
from mmdet.datasets import DATASETS

from mmtrack.datasets.parsers import CocoVID
from .base_sot_dataset import BaseSOTDataset


@DATASETS.register_module()
class SOTImageNetVIDDataset(BaseSOTDataset):
"""ImageNet VID dataset of single object tracking.
The dataset only support training mode.
"""

def __init__(self, ann_file, *args, **kwargs):
"""Initialization of SOT dataset class.
Args:
ann_file (str): The coco-format annotation file of ImageNet VID
Dataset. It will be loaded and parsed in the
`self.load_data_infos` function.
"""
self.coco = CocoVID(ann_file)
super().__init__(*args, **kwargs)

def load_data_infos(self, split='train'):
"""Load dataset information.
Args:
split (str, optional): The split of dataset. Defaults to 'train'.
Returns:
list[int]: The length of the list is the number of instances. The
elemment in the list is instance ID in coco API.
"""
data_infos = list(self.coco.instancesToImgs.keys())
return data_infos

def get_bboxes_from_video(self, video_ind):
"""Get bbox annotation about the instance in a video. Considering
`get_bboxes_from_video` in `SOTBaseDataset` is not compatible with
`SOTImageNetVIDDataset`, we oveload this function though it's not
called by `self.get_ann_infos_from_video`.
Args:
video_ind (int): video index. Each video_ind denotes an instance.
Returns:
ndarray: in [N, 4] shape. The bbox is in (x, y, w, h) format.
"""
instance_id = self.data_infos[video_ind]
img_ids = self.coco.instancesToImgs[instance_id]
bboxes = []
for img_id in img_ids:
for ann in self.coco.imgToAnns[img_id]:
if ann['instance_id'] == instance_id:
bboxes.append(ann['bbox'])
bboxes = np.array(bboxes).reshape(-1, 4)
return bboxes

def get_img_infos_from_video(self, video_ind):
"""Get image information in a video.
Args:
video_ind (int): video index
Returns:
dict: {'filename': list[str], 'frame_ids':ndarray, 'video_id':int}
"""
instance_id = self.data_infos[video_ind]
img_ids = self.coco.instancesToImgs[instance_id]
frame_ids = []
img_names = []
# In ImageNetVID dataset, frame_ids may not be continuous.
for img_id in img_ids:
frame_ids.append(self.coco.imgs[img_id]['frame_id'])
img_names.append(self.coco.imgs[img_id]['file_name'])
img_infos = dict(
filename=img_names, frame_ids=frame_ids, video_id=video_ind)
return img_infos

def get_ann_infos_from_video(self, video_ind):
"""Get annotation information in a video.
Note: We overload this function for speed up loading video information.
Args:
video_ind (int): video index. Each video_ind denotes an instance.
Returns:
dict: {'bboxes': ndarray in (N, 4) shape, 'bboxes_isvalid':
ndarray, 'visible':ndarray}. The bbox is in
(x1, y1, x2, y2) format.
"""
instance_id = self.data_infos[video_ind]
img_ids = self.coco.instancesToImgs[instance_id]
bboxes = []
visible = []
for img_id in img_ids:
for ann in self.coco.imgToAnns[img_id]:
if ann['instance_id'] == instance_id:
bboxes.append(ann['bbox'])
visible.append(not ann.get('occluded', False))
bboxes = np.array(bboxes).reshape(-1, 4)
bboxes_isvalid = (bboxes[:, 2] > self.bbox_min_size) & (
bboxes[:, 3] > self.bbox_min_size)
bboxes[:, 2:] += bboxes[:, :2]
visible = np.array(visible, dtype=np.bool_) & bboxes_isvalid
ann_infos = ann_infos = dict(
bboxes=bboxes, bboxes_isvalid=bboxes_isvalid, visible=visible)
return ann_infos

def get_visibility_from_video(self, video_ind):
"""Get the visible information in a video.
Considering `get_visibility_from_video` in `SOTBaseDataset` is not
compatible with `SOTImageNetVIDDataset`, we oveload this function
though it's not called by `self.get_ann_infos_from_video`.
"""
instance_id = self.data_infos[video_ind]
img_ids = self.coco.instancesToImgs[instance_id]
visible = []
for img_id in img_ids:
for ann in self.coco.imgToAnns[img_id]:
if ann['instance_id'] == instance_id:
visible.append(not ann.get('occluded', False))
visible_info = dict(visible=np.array(visible, dtype=np.bool_))
return visible_info

def get_len_per_video(self, video_ind):
"""Get the number of frames in a video."""
instance_id = self.data_infos[video_ind]
return len(self.coco.instancesToImgs[instance_id])
Loading

0 comments on commit bb0a643

Please sign in to comment.