Skip to content

Commit

Permalink
Feature visualization update (ultralytics#3920)
Browse files Browse the repository at this point in the history
* Feature visualization update

* Save to jpg (faster)

* Save to png
  • Loading branch information
glenn-jocher authored Jul 7, 2021
1 parent 1442d30 commit 7dac0ab
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 28 deletions.
6 changes: 5 additions & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
classes=None, # filter by class: --class 0, or --class 0 2 3
agnostic_nms=False, # class-agnostic NMS
augment=False, # augmented inference
visualize=False, # visualize features
update=False, # update all models
project='runs/detect', # save results to project/name
name='exp', # save results to project/name
Expand Down Expand Up @@ -100,7 +101,9 @@ def run(weights='yolov5s.pt', # model.pt path(s)

# Inference
t1 = time_synchronized()
pred = model(img, augment=augment)[0]
pred = model(img,
augment=augment,
visualize=increment_path(save_dir / 'features', mkdir=True) if visualize else False)[0]

# Apply NMS
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
Expand Down Expand Up @@ -201,6 +204,7 @@ def parse_opt():
parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--visualize', action='store_true', help='visualize features')
parser.add_argument('--update', action='store_true', help='update all models')
parser.add_argument('--project', default='runs/detect', help='save results to project/name')
parser.add_argument('--name', default='exp', help='save results to project/name')
Expand Down
11 changes: 5 additions & 6 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,10 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, i
self.info()
logger.info('')

def forward(self, x, augment=False, profile=False):
def forward(self, x, augment=False, profile=False, visualize=False):
if augment:
return self.forward_augment(x) # augmented inference, None
else:
return self.forward_once(x, profile) # single-scale inference, train
return self.forward_once(x, profile, visualize) # single-scale inference, train

def forward_augment(self, x):
img_size = x.shape[-2:] # height, width
Expand All @@ -136,7 +135,7 @@ def forward_augment(self, x):
y.append(yi)
return torch.cat(y, 1), None # augmented inference, train

def forward_once(self, x, profile=False, feature_vis=False):
def forward_once(self, x, profile=False, visualize=False):
y, dt = [], [] # outputs
for m in self.model:
if m.f != -1: # if not from previous layer
Expand All @@ -155,8 +154,8 @@ def forward_once(self, x, profile=False, feature_vis=False):
x = m(x) # run
y.append(x if m.i in self.save else None) # save output

if feature_vis and m.type == 'models.common.SPP':
feature_visualization(x, m.type, m.i)
if visualize:
feature_visualization(x, m.type, m.i, save_dir=visualize)

if profile:
logger.info('%.1fms total' % sum(dt))
Expand Down
39 changes: 18 additions & 21 deletions utils/plots.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Plotting utils

import glob
import math
import os
from copy import copy
from pathlib import Path

import cv2
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -15,7 +15,6 @@
import torch
import yaml
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms

from utils.general import increment_path, xywh2xyxy, xyxy2xywh
from utils.metrics import fitness
Expand Down Expand Up @@ -448,28 +447,26 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
fig.savefig(Path(save_dir) / 'results.png', dpi=200)


def feature_visualization(x, module_type, stage, n=64):
def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detect/exp')):
"""
x: Features to be visualized
module_type: Module type
stage: Module stage within model
n: Maximum number of feature maps to plot
save_dir: Directory to save results
"""
batch, channels, height, width = x.shape # batch, channels, height, width
if height > 1 and width > 1:
project, name = 'runs/features', 'exp'
save_dir = increment_path(Path(project) / name) # increment run
save_dir.mkdir(parents=True, exist_ok=True) # make dir

plt.figure(tight_layout=True)
blocks = torch.chunk(x, channels, dim=1) # block by channel dimension
n = min(n, len(blocks))
for i in range(n):
feature = transforms.ToPILImage()(blocks[i].squeeze())
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
ax.axis('off')
plt.imshow(feature) # cmap='gray'

f = f"stage_{stage}_{module_type.split('.')[-1]}_features.png"
print(f'Saving {save_dir / f}...')
plt.savefig(save_dir / f, dpi=300)
if 'Detect' not in module_type:
batch, channels, height, width = x.shape # batch, channels, height, width
if height > 1 and width > 1:
f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename

plt.figure(tight_layout=True)
blocks = torch.chunk(x[0], channels, dim=0) # select batch index 0, block by channels
n = min(n, channels) # number of plots
ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)[1].ravel() # 8 rows x n/8 cols
for i in range(n):
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
ax[i].axis('off')

print(f'Saving {save_dir / f}... ({n}/{channels})')
plt.savefig(save_dir / f, dpi=300)

0 comments on commit 7dac0ab

Please sign in to comment.