forked from open-mmlab/mmtracking
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Enhancement] Refactor TrackingNetDataset, SOTCocoDataset and SOTImag…
…eNetVIDDataset based on BaseSOTDataset (open-mmlab#402) * add trackingnet, sot_coco and sot_imagenetvid * fix docs and naming issues * change configs and fix test annotation bug * fix docs * fix formats * fix unit test * recover the original lasot json * json format * small changes * small changes * deprecate get_img_names_from_video and fix docs
- Loading branch information
1 parent
596e6cd
commit bb0a643
Showing
16 changed files
with
429 additions
and
141 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import time | ||
|
||
import numpy as np | ||
from mmdet.datasets import DATASETS | ||
from pycocotools.coco import COCO | ||
|
||
from .base_sot_dataset import BaseSOTDataset | ||
|
||
|
||
@DATASETS.register_module()
class SOTCocoDataset(BaseSOTDataset):
    """Coco dataset of single object tracking.

    Each object annotation is treated as a one-frame "video", so the
    dataset only supports training mode.
    """

    def __init__(self, ann_file, *args, **kwargs):
        """Initialization of SOT dataset class.

        Args:
            ann_file (str): The official coco annotation file. It will be
                loaded and parsed in the `self.load_data_infos` function.
        """
        # Parse the annotation file before the base-class init, since
        # `BaseSOTDataset.__init__` is expected to call
        # `self.load_data_infos`, which reads `self.coco`.
        self.coco = COCO(ann_file)
        super().__init__(*args, **kwargs)

    def load_data_infos(self, split='train'):
        """Load dataset information. Each instance is viewed as a video.

        Args:
            split (str, optional): The split of dataset. Defaults to 'train'.

        Returns:
            list[int]: The length of the list is the number of valid object
                annotations. Each element in the list is an annotation ID in
                the coco API.
        """
        print('Loading Coco dataset...')
        start_time = time.time()
        ann_ids = list(self.coco.anns.keys())
        # Crowd annotations cannot serve as single-object tracking targets,
        # so only keep annotations with `iscrowd == 0` (default 0 if absent).
        videos_list = [
            ann for ann in ann_ids
            if self.coco.anns[ann].get('iscrowd', 0) == 0
        ]
        print(f'Coco dataset loaded! ({time.time()-start_time:.2f} s)')
        return videos_list

    def get_bboxes_from_video(self, video_ind):
        """Get bbox annotation about the instance in an image.

        Args:
            video_ind (int): video index. Each video_ind denotes an instance.

        Returns:
            ndarray: in [1, 4] shape. The bbox is in (x, y, w, h) format.
        """
        ann_id = self.data_infos[video_ind]
        anno = self.coco.anns[ann_id]
        bboxes = np.array(anno['bbox']).reshape(-1, 4)
        return bboxes

    def get_img_infos_from_video(self, video_ind):
        """Get image information of the single frame in a "video".

        Args:
            video_ind (int): video index. Each video_ind denotes an instance.

        Returns:
            dict: {'filename': list[str], 'frame_ids': ndarray,
                'video_id': int}. Note the original docstring claimed a
                `list[str]` return, but a dict has always been returned.
        """
        ann_id = self.data_infos[video_ind]
        imgs = self.coco.loadImgs([self.coco.anns[ann_id]['image_id']])
        img_names = [img['file_name'] for img in imgs]
        # Each instance is a one-frame video, so this is simply [0].
        frame_ids = np.arange(self.get_len_per_video(video_ind))
        img_infos = dict(
            filename=img_names, frame_ids=frame_ids, video_id=video_ind)
        return img_infos

    def get_len_per_video(self, video_ind):
        """Get the number of frames in a video. Always 1 for coco images."""
        return 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
import numpy as np | ||
from mmdet.datasets import DATASETS | ||
|
||
from mmtrack.datasets.parsers import CocoVID | ||
from .base_sot_dataset import BaseSOTDataset | ||
|
||
|
||
@DATASETS.register_module()
class SOTImageNetVIDDataset(BaseSOTDataset):
    """ImageNet VID dataset of single object tracking.

    Each tracked instance is viewed as a video. The dataset only supports
    training mode.
    """

    def __init__(self, ann_file, *args, **kwargs):
        """Initialization of SOT dataset class.

        Args:
            ann_file (str): The coco-format annotation file of ImageNet VID
                Dataset. It will be loaded and parsed in the
                `self.load_data_infos` function.
        """
        # Parse the annotation file before the base-class init, since
        # `BaseSOTDataset.__init__` is expected to call
        # `self.load_data_infos`, which reads `self.coco`.
        self.coco = CocoVID(ann_file)
        super().__init__(*args, **kwargs)

    def load_data_infos(self, split='train'):
        """Load dataset information.

        Args:
            split (str, optional): The split of dataset. Defaults to 'train'.

        Returns:
            list[int]: The length of the list is the number of instances. Each
                element in the list is an instance ID in the coco API.
        """
        data_infos = list(self.coco.instancesToImgs.keys())
        return data_infos

    def get_bboxes_from_video(self, video_ind):
        """Get bbox annotation about the instance in a video. Considering
        `get_bboxes_from_video` in `SOTBaseDataset` is not compatible with
        `SOTImageNetVIDDataset`, we overload this function though it's not
        called by `self.get_ann_infos_from_video`.

        Args:
            video_ind (int): video index. Each video_ind denotes an instance.

        Returns:
            ndarray: in [N, 4] shape. The bbox is in (x, y, w, h) format.
        """
        instance_id = self.data_infos[video_ind]
        img_ids = self.coco.instancesToImgs[instance_id]
        bboxes = []
        for img_id in img_ids:
            # One image may contain several instances; only collect the
            # annotation belonging to this instance.
            for ann in self.coco.imgToAnns[img_id]:
                if ann['instance_id'] == instance_id:
                    bboxes.append(ann['bbox'])
        bboxes = np.array(bboxes).reshape(-1, 4)
        return bboxes

    def get_img_infos_from_video(self, video_ind):
        """Get image information in a video.

        Args:
            video_ind (int): video index

        Returns:
            dict: {'filename': list[str], 'frame_ids': list[int],
                'video_id': int}. Note `frame_ids` is a plain list here
                (unlike the ndarray in other SOT datasets).
        """
        instance_id = self.data_infos[video_ind]
        img_ids = self.coco.instancesToImgs[instance_id]
        frame_ids = []
        img_names = []
        # In ImageNetVID dataset, frame_ids may not be continuous.
        for img_id in img_ids:
            frame_ids.append(self.coco.imgs[img_id]['frame_id'])
            img_names.append(self.coco.imgs[img_id]['file_name'])
        img_infos = dict(
            filename=img_names, frame_ids=frame_ids, video_id=video_ind)
        return img_infos

    def get_ann_infos_from_video(self, video_ind):
        """Get annotation information in a video.

        Note: We overload this function in order to speed up loading video
        information.

        Args:
            video_ind (int): video index. Each video_ind denotes an instance.

        Returns:
            dict: {'bboxes': ndarray in (N, 4) shape, 'bboxes_isvalid':
                ndarray, 'visible': ndarray}. The bbox is in
                (x1, y1, x2, y2) format.
        """
        instance_id = self.data_infos[video_ind]
        img_ids = self.coco.instancesToImgs[instance_id]
        bboxes = []
        visible = []
        for img_id in img_ids:
            for ann in self.coco.imgToAnns[img_id]:
                if ann['instance_id'] == instance_id:
                    bboxes.append(ann['bbox'])
                    visible.append(not ann.get('occluded', False))
        bboxes = np.array(bboxes).reshape(-1, 4)
        # A bbox is valid only when both its width and height exceed the
        # minimal size configured on the base dataset.
        bboxes_isvalid = (bboxes[:, 2] > self.bbox_min_size) & (
            bboxes[:, 3] > self.bbox_min_size)
        # Convert (x, y, w, h) to (x1, y1, x2, y2) in place.
        bboxes[:, 2:] += bboxes[:, :2]
        # An object is visible only if it is both unoccluded and valid.
        visible = np.array(visible, dtype=np.bool_) & bboxes_isvalid
        # Fixed: the original line was `ann_infos = ann_infos = dict(...)`,
        # a harmless but confusing duplicated assignment.
        ann_infos = dict(
            bboxes=bboxes, bboxes_isvalid=bboxes_isvalid, visible=visible)
        return ann_infos

    def get_visibility_from_video(self, video_ind):
        """Get the visible information in a video.

        Considering `get_visibility_from_video` in `SOTBaseDataset` is not
        compatible with `SOTImageNetVIDDataset`, we overload this function
        though it's not called by `self.get_ann_infos_from_video`.
        """
        instance_id = self.data_infos[video_ind]
        img_ids = self.coco.instancesToImgs[instance_id]
        visible = []
        for img_id in img_ids:
            for ann in self.coco.imgToAnns[img_id]:
                if ann['instance_id'] == instance_id:
                    visible.append(not ann.get('occluded', False))
        visible_info = dict(visible=np.array(visible, dtype=np.bool_))
        return visible_info

    def get_len_per_video(self, video_ind):
        """Get the number of frames in a video."""
        instance_id = self.data_infos[video_ind]
        return len(self.coco.instancesToImgs[instance_id])
Oops, something went wrong.