forked from open-mmlab/mmyolo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Show YOLOv5 assigner results (open-mmlab#383)
* init commit * init commit * init commit * 定稿,开始重构 * format code * format code * add typehint and doc * init commit * init commit * init commit * 定稿,开始重构 * format code * format code * add typehint and doc * format code * rollback * add doc * fix less img bug * format code * format code * add README.md * beauty * beauty * uniform name * uniform name * uniform name * uniform name
- Loading branch information
Showing
9 changed files
with
721 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# MMYOLO Model Assigner Visualization | ||
|
||
<img src="https://user-images.githubusercontent.com/40284075/208255302-dbcf8cb0-b9d1-495f-8908-57dd2370dba8.png"/> | ||
|
||
## Introduction | ||
|
||
This project is developed for easily showing assigning results. The script allows users to analyze where and how many positive samples each gt is assigned in the image. | ||
|
||
Now, the script only support `YOLOv5` . | ||
|
||
## Usage | ||
|
||
### Command | ||
|
||
```shell | ||
python projects/assigner_visualization/assigner_visualization.py projects/assigner_visualization/configs/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py ` | ||
``` |
151 changes: 151 additions & 0 deletions
151
projects/assigner_visualization/assigner_visualization.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import argparse | ||
import os | ||
import os.path as osp | ||
import sys | ||
|
||
import mmcv | ||
import numpy as np | ||
import torch | ||
from mmengine import ProgressBar | ||
from mmengine.config import Config, DictAction | ||
from mmengine.dataset import COLLATE_FUNCTIONS | ||
from numpy import random | ||
|
||
from mmyolo.registry import DATASETS, MODELS | ||
from mmyolo.utils import register_all_modules | ||
from projects.assigner_visualization.dense_heads import YOLOv5HeadAssigner | ||
from projects.assigner_visualization.visualization import \ | ||
YOLOAssignerVisualizer | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser( | ||
description='MMYOLO show the positive sample assigning' | ||
' results.') | ||
parser.add_argument('config', help='config file path') | ||
parser.add_argument( | ||
'--show-number', | ||
'-n', | ||
type=int, | ||
default=sys.maxsize, | ||
help='number of images selected to save, ' | ||
'must bigger than 0. if the number is bigger than length ' | ||
'of dataset, show all the images in dataset; ' | ||
'default "sys.maxsize", show all images in dataset') | ||
parser.add_argument( | ||
'--output-dir', | ||
default='assigned_results', | ||
type=str, | ||
help='The name of the folder where the image is saved.') | ||
parser.add_argument( | ||
'--device', default='cuda:0', help='Device used for inference.') | ||
parser.add_argument( | ||
'--show-prior', | ||
default=False, | ||
action='store_true', | ||
help='Whether to show prior on image.') | ||
parser.add_argument( | ||
'--not-show-label', | ||
default=False, | ||
action='store_true', | ||
help='Whether to show label on image.') | ||
parser.add_argument('--seed', default=-1, type=int, help='random seed') | ||
parser.add_argument( | ||
'--cfg-options', | ||
nargs='+', | ||
action=DictAction, | ||
help='override some settings in the used config, the key-value pair ' | ||
'in xxx=yyy format will be merged into config file. If the value to ' | ||
'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' | ||
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' | ||
'Note that the quotation marks are necessary and that no white space ' | ||
'is allowed.') | ||
|
||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
register_all_modules() | ||
|
||
# set random seed | ||
seed = int(args.seed) | ||
if seed != -1: | ||
print(f'Set the global seed: {seed}') | ||
random.seed(int(args.seed)) | ||
|
||
cfg = Config.fromfile(args.config) | ||
if args.cfg_options is not None: | ||
cfg.merge_from_dict(args.cfg_options) | ||
|
||
# build model | ||
model = MODELS.build(cfg.model) | ||
assert isinstance(model.bbox_head, YOLOv5HeadAssigner),\ | ||
'Now, this script only support yolov5, and bbox_head must use ' \ | ||
'`YOLOv5HeadAssigner`. Please use `' \ | ||
'yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py' \ | ||
'` as config file.' | ||
model.eval() | ||
model.to(args.device) | ||
|
||
# build dataset | ||
dataset_cfg = cfg.get('train_dataloader').get('dataset') | ||
dataset = DATASETS.build(dataset_cfg) | ||
|
||
# get collate_fn | ||
collate_fn_cfg = cfg.get('train_dataloader').pop( | ||
'collate_fn', dict(type='pseudo_collate')) | ||
collate_fn_type = collate_fn_cfg.pop('type') | ||
collate_fn = COLLATE_FUNCTIONS.get(collate_fn_type) | ||
|
||
# init visualizer | ||
visualizer = YOLOAssignerVisualizer( | ||
vis_backends=[{ | ||
'type': 'LocalVisBackend' | ||
}], name='visualizer') | ||
visualizer.dataset_meta = dataset.metainfo | ||
# need priors size to draw priors | ||
visualizer.priors_size = model.bbox_head.prior_generator.base_anchors | ||
|
||
# make output dir | ||
os.makedirs(args.output_dir, exist_ok=True) | ||
|
||
# init visualization image number | ||
assert args.show_number > 0 | ||
display_number = min(args.show_number, len(dataset)) | ||
|
||
progress_bar = ProgressBar(display_number) | ||
for ind_img in range(display_number): | ||
data = dataset.prepare_data(ind_img) | ||
|
||
# convert data to batch format | ||
batch_data = collate_fn([data]) | ||
with torch.no_grad(): | ||
assign_results = model.assign(batch_data) | ||
|
||
img = data['inputs'].cpu().numpy().astype(np.uint8).transpose( | ||
(1, 2, 0)) | ||
# bgr2rgb | ||
img = mmcv.bgr2rgb(img) | ||
|
||
gt_instances = data['data_samples'].gt_instances | ||
|
||
img_show = visualizer.draw_assign(img, assign_results, gt_instances, | ||
args.show_prior, args.not_show_label) | ||
|
||
if hasattr(data['data_samples'], 'img_path'): | ||
filename = osp.basename(data['data_samples'].img_path) | ||
else: | ||
# some dataset have not image path | ||
filename = f'{ind_img}.jpg' | ||
out_file = osp.join(args.output_dir, filename) | ||
|
||
# convert rgb 2 bgr and save img | ||
mmcv.imwrite(mmcv.rgb2bgr(img_show), out_file) | ||
progress_bar.update() | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
11 changes: 11 additions & 0 deletions
11
...r_visualization/configs/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
_base_ = [ | ||
'../../../configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py' | ||
] | ||
|
||
custom_imports = dict(imports=[ | ||
'projects.assigner_visualization.detectors', | ||
'projects.assigner_visualization.dense_heads' | ||
]) | ||
|
||
model = dict( | ||
type='YOLODetectorAssigner', bbox_head=dict(type='YOLOv5HeadAssigner')) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .yolov5_head_assigner import YOLOv5HeadAssigner | ||
|
||
__all__ = ['YOLOv5HeadAssigner'] |
188 changes: 188 additions & 0 deletions
188
projects/assigner_visualization/dense_heads/yolov5_head_assigner.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import Sequence, Union | ||
|
||
import torch | ||
from mmdet.models.utils import unpack_gt_instances | ||
from mmengine.structures import InstanceData | ||
from torch import Tensor | ||
|
||
from mmyolo.models import YOLOv5Head | ||
from mmyolo.registry import MODELS | ||
|
||
|
||
@MODELS.register_module() | ||
class YOLOv5HeadAssigner(YOLOv5Head): | ||
|
||
def assign_by_gt_and_feat( | ||
self, | ||
batch_gt_instances: Sequence[InstanceData], | ||
batch_img_metas: Sequence[dict], | ||
inputs_hw: Union[Tensor, tuple] = (640, 640) | ||
) -> dict: | ||
"""Calculate the assigning results based on the gt and features | ||
extracted by the detection head. | ||
Args: | ||
batch_gt_instances (Sequence[InstanceData]): Batch of | ||
gt_instance. It usually includes ``bboxes`` and ``labels`` | ||
attributes. | ||
batch_img_metas (Sequence[dict]): Meta information of each image, | ||
e.g., image size, scaling factor, etc. | ||
batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): | ||
Batch of gt_instances_ignore. It includes ``bboxes`` attribute | ||
data that is ignored during training and testing. | ||
Defaults to None. | ||
inputs_hw (Union[Tensor, tuple]): Height and width of inputs size. | ||
Returns: | ||
dict[str, Tensor]: A dictionary of assigning results. | ||
""" | ||
# 1. Convert gt to norm format | ||
batch_targets_normed = self._convert_gt_to_norm_format( | ||
batch_gt_instances, batch_img_metas) | ||
|
||
device = batch_targets_normed.device | ||
scaled_factor = torch.ones(7, device=device) | ||
gt_inds = torch.arange( | ||
batch_targets_normed.shape[1], | ||
dtype=torch.long, | ||
device=device, | ||
requires_grad=False).unsqueeze(0).repeat((self.num_base_priors, 1)) | ||
|
||
assign_results = [] | ||
for i in range(self.num_levels): | ||
assign_results_feat = [] | ||
h = inputs_hw[0] // self.featmap_strides[i] | ||
w = inputs_hw[1] // self.featmap_strides[i] | ||
|
||
# empty gt bboxes | ||
if batch_targets_normed.shape[1] == 0: | ||
for k in range(self.num_base_priors): | ||
assign_results_feat.append({ | ||
'stride': | ||
self.featmap_strides[i], | ||
'grid_x_inds': | ||
torch.zeros([0], dtype=torch.int64).to(device), | ||
'grid_y_inds': | ||
torch.zeros([0], dtype=torch.int64).to(device), | ||
'img_inds': | ||
torch.zeros([0], dtype=torch.int64).to(device), | ||
'class_inds': | ||
torch.zeros([0], dtype=torch.int64).to(device), | ||
'retained_gt_inds': | ||
torch.zeros([0], dtype=torch.int64).to(device), | ||
'prior_ind': | ||
k | ||
}) | ||
assign_results.append(assign_results_feat) | ||
continue | ||
|
||
priors_base_sizes_i = self.priors_base_sizes[i] | ||
# feature map scale whwh | ||
scaled_factor[2:6] = torch.tensor([w, h, w, h]) | ||
# Scale batch_targets from range 0-1 to range 0-features_maps size. | ||
# (num_base_priors, num_bboxes, 7) | ||
batch_targets_scaled = batch_targets_normed * scaled_factor | ||
|
||
# 2. Shape match | ||
wh_ratio = batch_targets_scaled[..., | ||
4:6] / priors_base_sizes_i[:, None] | ||
match_inds = torch.max( | ||
wh_ratio, 1 / wh_ratio).max(2)[0] < self.prior_match_thr | ||
batch_targets_scaled = batch_targets_scaled[match_inds] | ||
match_gt_inds = gt_inds[match_inds] | ||
|
||
# no gt bbox matches anchor | ||
if batch_targets_scaled.shape[0] == 0: | ||
for k in range(self.num_base_priors): | ||
assign_results_feat.append({ | ||
'stride': | ||
self.featmap_strides[i], | ||
'grid_x_inds': | ||
torch.zeros([0], dtype=torch.int64).to(device), | ||
'grid_y_inds': | ||
torch.zeros([0], dtype=torch.int64).to(device), | ||
'img_inds': | ||
torch.zeros([0], dtype=torch.int64).to(device), | ||
'class_inds': | ||
torch.zeros([0], dtype=torch.int64).to(device), | ||
'retained_gt_inds': | ||
torch.zeros([0], dtype=torch.int64).to(device), | ||
'prior_ind': | ||
k | ||
}) | ||
assign_results.append(assign_results_feat) | ||
continue | ||
|
||
# 3. Positive samples with additional neighbors | ||
|
||
# check the left, up, right, bottom sides of the | ||
# targets grid, and determine whether assigned | ||
# them as positive samples as well. | ||
batch_targets_cxcy = batch_targets_scaled[:, 2:4] | ||
grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy | ||
left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) & | ||
(batch_targets_cxcy > 1)).T | ||
right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) & | ||
(grid_xy > 1)).T | ||
offset_inds = torch.stack( | ||
(torch.ones_like(left), left, up, right, bottom)) | ||
|
||
batch_targets_scaled = batch_targets_scaled.repeat( | ||
(5, 1, 1))[offset_inds] | ||
retained_gt_inds = match_gt_inds.repeat((5, 1))[offset_inds] | ||
retained_offsets = self.grid_offset.repeat(1, offset_inds.shape[1], | ||
1)[offset_inds] | ||
|
||
# prepare pred results and positive sample indexes to | ||
# calculate class loss and bbox lo | ||
_chunk_targets = batch_targets_scaled.chunk(4, 1) | ||
img_class_inds, grid_xy, grid_wh, priors_inds = _chunk_targets | ||
priors_inds, (img_inds, class_inds) = priors_inds.long().view( | ||
-1), img_class_inds.long().T | ||
|
||
grid_xy_long = (grid_xy - | ||
retained_offsets * self.near_neighbor_thr).long() | ||
grid_x_inds, grid_y_inds = grid_xy_long.T | ||
for k in range(self.num_base_priors): | ||
retained_inds = priors_inds == k | ||
assign_results_prior = { | ||
'stride': self.featmap_strides[i], | ||
'grid_x_inds': grid_x_inds[retained_inds], | ||
'grid_y_inds': grid_y_inds[retained_inds], | ||
'img_inds': img_inds[retained_inds], | ||
'class_inds': class_inds[retained_inds], | ||
'retained_gt_inds': retained_gt_inds[retained_inds], | ||
'prior_ind': k | ||
} | ||
assign_results_feat.append(assign_results_prior) | ||
assign_results.append(assign_results_feat) | ||
return assign_results | ||
|
||
def assign(self, batch_data_samples: Union[list, dict], | ||
inputs_hw: Union[tuple, torch.Size]) -> dict: | ||
"""Calculate assigning results. This function is provided to the | ||
`assigner_visualization.py` script. | ||
Args: | ||
batch_data_samples (List[:obj:`DetDataSample`], dict): The Data | ||
Samples. It usually includes information such as | ||
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. | ||
inputs_hw: Height and width of inputs size | ||
Returns: | ||
dict: A dictionary of assigning components. | ||
""" | ||
if isinstance(batch_data_samples, list): | ||
outputs = unpack_gt_instances(batch_data_samples) | ||
(batch_gt_instances, batch_gt_instances_ignore, | ||
batch_img_metas) = outputs | ||
|
||
assign_inputs = (batch_gt_instances, batch_img_metas, | ||
batch_gt_instances_ignore, inputs_hw) | ||
else: | ||
# Fast version | ||
assign_inputs = (batch_data_samples['bboxes_labels'], | ||
batch_data_samples['img_metas'], inputs_hw) | ||
assign_results = self.assign_by_gt_and_feat(*assign_inputs) | ||
|
||
return assign_results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from projects.assigner_visualization.detectors.yolo_detector_assigner import \ | ||
YOLODetectorAssigner | ||
|
||
__all__ = ['YOLODetectorAssigner'] |
Oops, something went wrong.