[Feature] Show YOLOv5 assigner results (open-mmlab#383)
* init commit

* init commit

* init commit

* finalize the draft, start refactoring

* format code

* format code

* add typehint and doc

* init commit

* init commit

* init commit

* finalize the draft, start refactoring

* format code

* format code

* add typehint and doc

* format code

* rollback

* add doc

* fix less img bug

* format code

* format code

* add README.md

* beauty

* beauty

* uniform name

* uniform name

* uniform name

* uniform name
Nioolek authored Dec 28, 2022
1 parent bb4aea9 commit 9ef8831
Showing 9 changed files with 721 additions and 0 deletions.
17 changes: 17 additions & 0 deletions projects/assigner_visualization/README.md
@@ -0,0 +1,17 @@
# MMYOLO Model Assigner Visualization

<img src="https://user-images.githubusercontent.com/40284075/208255302-dbcf8cb0-b9d1-495f-8908-57dd2370dba8.png"/>

## Introduction

This project is developed to easily visualize label assignment results. The script lets users analyze where, and how many, positive samples each gt is assigned in the image.

Currently, the script only supports `YOLOv5`.

## Usage

### Command

```shell
python projects/assigner_visualization/assigner_visualization.py projects/assigner_visualization/configs/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py
```
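
The optional arguments defined in `parse_args` of `assigner_visualization.py` (shown below) can be appended, for example to save only the first 10 images and also draw the priors:

```shell
python projects/assigner_visualization/assigner_visualization.py \
    projects/assigner_visualization/configs/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py \
    --show-number 10 --show-prior --output-dir assigned_results
```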
151 changes: 151 additions & 0 deletions projects/assigner_visualization/assigner_visualization.py
@@ -0,0 +1,151 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import sys

import mmcv
import numpy as np
import torch
from mmengine import ProgressBar
from mmengine.config import Config, DictAction
from mmengine.dataset import COLLATE_FUNCTIONS
from numpy import random

from mmyolo.registry import DATASETS, MODELS
from mmyolo.utils import register_all_modules
from projects.assigner_visualization.dense_heads import YOLOv5HeadAssigner
from projects.assigner_visualization.visualization import \
YOLOAssignerVisualizer


def parse_args():
    parser = argparse.ArgumentParser(
        description='Show the positive sample assigning results in MMYOLO.')
parser.add_argument('config', help='config file path')
parser.add_argument(
'--show-number',
'-n',
type=int,
default=sys.maxsize,
        help='number of images selected to save; '
        'must be greater than 0. If the number is greater than the '
        'length of the dataset, all images in the dataset are shown. '
        'Defaults to "sys.maxsize", which shows all images in the dataset')
parser.add_argument(
'--output-dir',
default='assigned_results',
type=str,
        help='The name of the folder where the images are saved.')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference.')
    parser.add_argument(
        '--show-prior',
        default=False,
        action='store_true',
        help='Whether to show the priors on the image.')
    parser.add_argument(
        '--not-show-label',
        default=False,
        action='store_true',
        help='Whether to hide the labels on the image.')
parser.add_argument('--seed', default=-1, type=int, help='random seed')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')

args = parser.parse_args()
return args


def main():
args = parse_args()
register_all_modules()

# set random seed
seed = int(args.seed)
if seed != -1:
print(f'Set the global seed: {seed}')
random.seed(int(args.seed))

cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

# build model
model = MODELS.build(cfg.model)
    assert isinstance(model.bbox_head, YOLOv5HeadAssigner),\
        'Currently, this script only supports YOLOv5, and bbox_head must ' \
        'use `YOLOv5HeadAssigner`. Please use `' \
        'yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py' \
        '` as the config file.'
model.eval()
model.to(args.device)

# build dataset
dataset_cfg = cfg.get('train_dataloader').get('dataset')
dataset = DATASETS.build(dataset_cfg)

# get collate_fn
collate_fn_cfg = cfg.get('train_dataloader').pop(
'collate_fn', dict(type='pseudo_collate'))
collate_fn_type = collate_fn_cfg.pop('type')
collate_fn = COLLATE_FUNCTIONS.get(collate_fn_type)

# init visualizer
visualizer = YOLOAssignerVisualizer(
vis_backends=[{
'type': 'LocalVisBackend'
}], name='visualizer')
visualizer.dataset_meta = dataset.metainfo
# need priors size to draw priors
visualizer.priors_size = model.bbox_head.prior_generator.base_anchors

# make output dir
os.makedirs(args.output_dir, exist_ok=True)

# init visualization image number
assert args.show_number > 0
display_number = min(args.show_number, len(dataset))

progress_bar = ProgressBar(display_number)
for ind_img in range(display_number):
data = dataset.prepare_data(ind_img)

# convert data to batch format
batch_data = collate_fn([data])
with torch.no_grad():
assign_results = model.assign(batch_data)

img = data['inputs'].cpu().numpy().astype(np.uint8).transpose(
(1, 2, 0))
# bgr2rgb
img = mmcv.bgr2rgb(img)

gt_instances = data['data_samples'].gt_instances

img_show = visualizer.draw_assign(img, assign_results, gt_instances,
args.show_prior, args.not_show_label)

if hasattr(data['data_samples'], 'img_path'):
filename = osp.basename(data['data_samples'].img_path)
else:
            # some datasets do not have an image path
filename = f'{ind_img}.jpg'
out_file = osp.join(args.output_dir, filename)

        # convert RGB to BGR and save the image
mmcv.imwrite(mmcv.rgb2bgr(img_show), out_file)
progress_bar.update()


if __name__ == '__main__':
main()
11 changes: 11 additions & 0 deletions projects/assigner_visualization/configs/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_assignervisualization.py
@@ -0,0 +1,11 @@
_base_ = [
'../../../configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py'
]

custom_imports = dict(imports=[
'projects.assigner_visualization.detectors',
'projects.assigner_visualization.dense_heads'
])

model = dict(
type='YOLODetectorAssigner', bbox_head=dict(type='YOLOv5HeadAssigner'))
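
The `custom_imports` field makes mmengine import the listed project modules when the config is loaded, which executes their `@MODELS.register_module()` decorators before the model is built. Roughly (a sketch of mmengine's behavior, not code from this PR):

```python
# Approximate effect of `custom_imports` at config load time
# (illustration only): importing the modules registers the classes.
import importlib

for mod in ('projects.assigner_visualization.detectors',
            'projects.assigner_visualization.dense_heads'):
    importlib.import_module(mod)
```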
4 changes: 4 additions & 0 deletions projects/assigner_visualization/dense_heads/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .yolov5_head_assigner import YOLOv5HeadAssigner

__all__ = ['YOLOv5HeadAssigner']
188 changes: 188 additions & 0 deletions projects/assigner_visualization/dense_heads/yolov5_head_assigner.py
@@ -0,0 +1,188 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union

import torch
from mmdet.models.utils import unpack_gt_instances
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.models import YOLOv5Head
from mmyolo.registry import MODELS


@MODELS.register_module()
class YOLOv5HeadAssigner(YOLOv5Head):

def assign_by_gt_and_feat(
self,
batch_gt_instances: Sequence[InstanceData],
batch_img_metas: Sequence[dict],
inputs_hw: Union[Tensor, tuple] = (640, 640)
) -> dict:
"""Calculate the assigning results based on the gt and features
extracted by the detection head.
Args:
batch_gt_instances (Sequence[InstanceData]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
batch_img_metas (Sequence[dict]): Meta information of each image,
e.g., image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
inputs_hw (Union[Tensor, tuple]): Height and width of inputs size.
Returns:
dict[str, Tensor]: A dictionary of assigning results.
"""
# 1. Convert gt to norm format
batch_targets_normed = self._convert_gt_to_norm_format(
batch_gt_instances, batch_img_metas)
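        # batch_targets_normed has shape (num_base_priors, num_gts, 7);
        # the last dim is (img_ind, class_ind, cx, cy, w, h, prior_ind)
        # in normalized coordinates (see the chunk(4, 1) split below).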

device = batch_targets_normed.device
scaled_factor = torch.ones(7, device=device)
gt_inds = torch.arange(
batch_targets_normed.shape[1],
dtype=torch.long,
device=device,
requires_grad=False).unsqueeze(0).repeat((self.num_base_priors, 1))

assign_results = []
for i in range(self.num_levels):
assign_results_feat = []
h = inputs_hw[0] // self.featmap_strides[i]
w = inputs_hw[1] // self.featmap_strides[i]

# empty gt bboxes
if batch_targets_normed.shape[1] == 0:
for k in range(self.num_base_priors):
assign_results_feat.append({
'stride':
self.featmap_strides[i],
'grid_x_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'grid_y_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'img_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'class_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'retained_gt_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'prior_ind':
k
})
assign_results.append(assign_results_feat)
continue

priors_base_sizes_i = self.priors_base_sizes[i]
# feature map scale whwh
scaled_factor[2:6] = torch.tensor([w, h, w, h])
# Scale batch_targets from range 0-1 to range 0-features_maps size.
# (num_base_priors, num_bboxes, 7)
batch_targets_scaled = batch_targets_normed * scaled_factor

# 2. Shape match
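            # YOLOv5's shape-match rule: a gt matches prior k only if
            # max(w_gt / w_k, h_gt / h_k, w_k / w_gt, h_k / h_gt)
            # < self.prior_match_thr; no IoU is computed.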
wh_ratio = batch_targets_scaled[...,
4:6] / priors_base_sizes_i[:, None]
match_inds = torch.max(
wh_ratio, 1 / wh_ratio).max(2)[0] < self.prior_match_thr
batch_targets_scaled = batch_targets_scaled[match_inds]
match_gt_inds = gt_inds[match_inds]

# no gt bbox matches anchor
if batch_targets_scaled.shape[0] == 0:
for k in range(self.num_base_priors):
assign_results_feat.append({
'stride':
self.featmap_strides[i],
'grid_x_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'grid_y_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'img_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'class_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'retained_gt_inds':
torch.zeros([0], dtype=torch.int64).to(device),
'prior_ind':
k
})
assign_results.append(assign_results_feat)
continue

# 3. Positive samples with additional neighbors

            # Check the left, up, right, and bottom sides of the
            # target's grid cell, and determine whether to assign
            # them as positive samples as well.
batch_targets_cxcy = batch_targets_scaled[:, 2:4]
grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) &
(batch_targets_cxcy > 1)).T
right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) &
(grid_xy > 1)).T
offset_inds = torch.stack(
(torch.ones_like(left), left, up, right, bottom))

batch_targets_scaled = batch_targets_scaled.repeat(
(5, 1, 1))[offset_inds]
retained_gt_inds = match_gt_inds.repeat((5, 1))[offset_inds]
retained_offsets = self.grid_offset.repeat(1, offset_inds.shape[1],
1)[offset_inds]

            # prepare pred results and positive sample indexes to
            # calculate class loss and bbox loss
_chunk_targets = batch_targets_scaled.chunk(4, 1)
img_class_inds, grid_xy, grid_wh, priors_inds = _chunk_targets
priors_inds, (img_inds, class_inds) = priors_inds.long().view(
-1), img_class_inds.long().T

grid_xy_long = (grid_xy -
retained_offsets * self.near_neighbor_thr).long()
grid_x_inds, grid_y_inds = grid_xy_long.T
for k in range(self.num_base_priors):
retained_inds = priors_inds == k
assign_results_prior = {
'stride': self.featmap_strides[i],
'grid_x_inds': grid_x_inds[retained_inds],
'grid_y_inds': grid_y_inds[retained_inds],
'img_inds': img_inds[retained_inds],
'class_inds': class_inds[retained_inds],
'retained_gt_inds': retained_gt_inds[retained_inds],
'prior_ind': k
}
assign_results_feat.append(assign_results_prior)
assign_results.append(assign_results_feat)
return assign_results

def assign(self, batch_data_samples: Union[list, dict],
inputs_hw: Union[tuple, torch.Size]) -> dict:
"""Calculate assigning results. This function is provided to the
`assigner_visualization.py` script.
Args:
batch_data_samples (List[:obj:`DetDataSample`], dict): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
inputs_hw: Height and width of inputs size
Returns:
dict: A dictionary of assigning components.
"""
if isinstance(batch_data_samples, list):
outputs = unpack_gt_instances(batch_data_samples)
(batch_gt_instances, batch_gt_instances_ignore,
batch_img_metas) = outputs

            assign_inputs = (batch_gt_instances, batch_img_metas, inputs_hw)
else:
# Fast version
assign_inputs = (batch_data_samples['bboxes_labels'],
batch_data_samples['img_metas'], inputs_hw)
assign_results = self.assign_by_gt_and_feat(*assign_inputs)

return assign_results
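
As a usage note, the returned `assign_results` is a list with one entry per feature level, each a list with one dict per base prior, so tallying how many positive samples each gt receives takes only a few lines. A minimal sketch (the helper name is hypothetical, not part of this PR):

```python
from collections import Counter

def count_positives_per_gt(assign_results):
    """Hypothetical helper: count positive samples per gt index
    across all feature levels and base priors."""
    counter = Counter()
    for level_results in assign_results:      # one list per feature level
        for prior_results in level_results:   # one dict per base prior
            counter.update(prior_results['retained_gt_inds'].tolist())
    return counter
```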
5 changes: 5 additions & 0 deletions projects/assigner_visualization/detectors/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from projects.assigner_visualization.detectors.yolo_detector_assigner import \
YOLODetectorAssigner

__all__ = ['YOLODetectorAssigner']