Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of : HIC-YOLOv5: Improved YOLOv5 for Small Object Detection #12264

Open
wants to merge 25 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e7e1cdb
imp of CBAM + Involution at common.py
aash1999 Oct 21, 2023
16fd02c
import CBAm and Involution into yolo.py
aash1999 Oct 21, 2023
7eff0ef
handle GPU err on
aash1999 Oct 21, 2023
b7715ca
Merge pull request #1 from aash1999/cbam-imp
aash1999 Oct 21, 2023
55ea408
added arch. backbone to /models/
aash1999 Oct 21, 2023
02469f2
readme update
aash1999 Oct 21, 2023
a27e8d1
Merge pull request #2 from aash1999/cbam-imp
aash1999 Oct 21, 2023
b1b1ab9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2023
2ee59f6
Update general.py
aash1999 Oct 24, 2023
79112df
Update common.py
aash1999 Oct 24, 2023
3d46323
adding hyp and model files as mentioned in paper
aash1999 Oct 24, 2023
f0b2ffc
Merge pull request #4 from aash1999/cbam-imp
aash1999 Oct 24, 2023
1204c74
Delete models/yolo5m-cbam-involution.yaml
aash1999 Oct 24, 2023
947266a
Update general.py
aash1999 Oct 24, 2023
a56bf81
Update yolov5s-cbam-involution.yaml
aash1999 Oct 25, 2023
5208303
Update CITATION.cff
aash1999 Oct 25, 2023
ccf2664
removed trailing spaces in general.py
aash1999 Oct 25, 2023
16ed93a
yapf formatting
aash1999 Oct 25, 2023
11ddc58
yapf formatting
aash1999 Oct 25, 2023
02bf256
Delete CITATION.cff
aash1999 Oct 25, 2023
1f85ade
reverting the files to commit 4d687c8
aash1999 Oct 25, 2023
8738c27
yapf reformat
aash1999 Oct 25, 2023
0fd8fe3
movig files to where they belong
aash1999 Oct 25, 2023
2fc73ca
typo correction
aash1999 Oct 25, 2023
ad78882
Merge branch 'master' into master
aash1999 Oct 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions data/hyps/hyp.hic-yolov5s.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
# hyperparameters for HIC-YOLOv5 for small object detection on VisDrone Dataset
# python train.py --hyp hyp.hic-yolov5s.yaml

lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.0005 # optimizer weight decay 5e-4
warmup_epochs: 3.0 # warmup epochs (fractions ok)
warmup_momentum: 0.8 # warmup initial momentum
warmup_bias_lr: 0.1 # warmup initial bias lr
box: 0.05 # box loss gain
cls: 0.25 # cls loss gain
cls_pw: 1.0 # cls BCELoss positive_weight
obj: 0.5 # obj loss gain (scale with pixels)
obj_pw: 1.0 # obj BCELoss positive_weight
iou_t: 0.20 # IoU training threshold
anchor_t: 4.0 # anchor-multiple threshold
# anchors: 3 # anchors per output layer (0 to ignore)
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
hsv_h: 0.4 # image HSV-Hue augmentation (fraction)
hsv_s: 0.3 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.5 # image HSV-Value augmentation (fraction)
degrees: 0.2 # image rotation (+/- deg)
translate: 0.1 # image translation (+/- fraction)
scale: 0.4 # image scale (+/- gain)
shear: 0.0 # image shear (+/- deg)
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
flipud: 0.0 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mixup: 0.2 # image mixup (probability)
copy_paste: 0.1 # segment copy-paste (probability)
162 changes: 162 additions & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,3 +881,165 @@ def forward(self, x):
if isinstance(x, list):
x = torch.cat(x, 1)
return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))


# contributed by @aash1999
class ChannelAttention(nn.Module):
    """Channel gate: one sigmoid weight per channel, derived from globally pooled descriptors."""

    def __init__(self, in_planes, ratio=16):
        """
        Initialize the Channel Attention module.

        Args:
            in_planes (int): Number of input channels.
            ratio (int): Reduction ratio for the hidden channels in the channel attention block.
        """
        super().__init__()
        # `in_planes // ratio` is 0 whenever in_planes < ratio, which would build an empty
        # bottleneck; always keep at least one hidden channel.
        hidden_planes = max(in_planes // ratio, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(hidden_planes, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Forward pass of the Channel Attention module.

        Args:
            x (torch.Tensor): Input feature map with `in_planes` channels.

        Returns:
            out (torch.Tensor): Sigmoid attention weights, one value per channel
                (spatial size 1x1), intended to be multiplied with the input by the caller.
        """
        # Any warning raised while computing the gate is silenced for the duration of this block.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            # The same two-layer bottleneck (f1 -> ReLU -> f2) is shared by both pooled descriptors.
            avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
            max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
            out = self.sigmoid(avg_out + max_out)
        return out


# contributed by @aash1999
class SpatialAttention(nn.Module):
    """Spatial gate: a single-channel sigmoid map built from channel-wise mean and max."""

    def __init__(self, kernel_size=7):
        """
        Build the spatial attention layers.

        Args:
            kernel_size (int): Kernel size of the attention convolution; only 3 and 7 are accepted.
        """
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # Padding is chosen so the convolution preserves the spatial size (1 for k=3, 3 for k=7).
        same_padding = 1 if kernel_size != 7 else 3
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=same_padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Compute the spatial attention map for `x`.

        Args:
            x (torch.Tensor): Input feature map; statistics are taken over dim 1 (channels).

        Returns:
            out (torch.Tensor): Sigmoid attention map with a single channel.
        """
        # Warnings raised inside this block are silenced.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            channel_mean = x.mean(dim=1, keepdim=True)
            channel_max = x.max(dim=1, keepdim=True).values
            # Stack both descriptors as a 2-channel map and reduce them to one channel.
            stacked = torch.cat([channel_mean, channel_max], dim=1)
            attention = self.sigmoid(self.conv(stacked))
        return attention


# contributed by @aash1999
class CBAM(nn.Module):
    # ch_in, ch_out, kernel_size, shortcut, groups, expansion, ratio
    def __init__(self, c1, c2, kernel_size=3, shortcut=True, g=1, e=0.5, ratio=16):
        """
        Build a two-convolution bottleneck followed by channel and spatial attention.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels.
            kernel_size (int): Kernel size handed to the spatial attention block
                (the second bottleneck convolution always uses a 3x3 kernel).
            shortcut (bool): Request a residual connection; only used when c1 == c2.
            g (int): Number of groups for the second bottleneck convolution.
            e (float): Expansion factor that sets the hidden channel count, int(c2 * e).
            ratio (int): Reduction ratio passed to the channel attention block.
        """
        super().__init__()
        hidden_channels = int(c2 * e)
        self.cv1 = Conv(c1, hidden_channels, 1, 1)
        self.cv2 = Conv(hidden_channels, c2, 3, 1, g=g)
        # The residual add is only shape-compatible when input and output widths match.
        self.add = shortcut and c1 == c2
        self.channel_attention = ChannelAttention(c2, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        """
        Apply the bottleneck, then channel attention, then spatial attention.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            out (torch.Tensor): Attention-refined features, with `x` added back
                when the shortcut is active.
        """
        # Warnings raised inside this block are silenced.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            features = self.cv1(x)
            features = self.cv2(features)
            # Channel gate first, then the spatial gate on the channel-refined features.
            refined = self.channel_attention(features) * features
            refined = self.spatial_attention(refined) * refined
        if self.add:
            return x + refined
        return refined


# contributed by @aash1999
class Involution(nn.Module):
    """Involution layer: a kernel is generated per spatial position and shared within channel groups."""

    def __init__(self, c1, c2, kernel_size, stride):
        """
        Initialize the Involution module.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels. Accepted for signature compatibility with the
                other layers; the tensor returned by forward() always has `c1` channels.
            kernel_size (int): Size of the involution kernel.
            stride (int): Stride for the involution operation.
        """
        super().__init__()
        import math  # local import: the module-level import block is not visible from this block

        self.kernel_size = kernel_size
        self.stride = stride
        self.c1 = c1
        reduction_ratio = 1
        # Use 16 channels per group when c1 is a multiple of 16, otherwise the largest divisor of
        # 16 that also divides c1, so that groups * group_channels == c1 always holds and the
        # reshapes in forward() are valid.
        self.group_channels = math.gcd(c1, 16)
        self.groups = self.c1 // self.group_channels
        self.conv1 = Conv(c1, c1 // reduction_ratio, 1)
        self.conv2 = Conv(c1 // reduction_ratio, kernel_size ** 2 * self.groups, 1, 1)

        if stride > 1:
            self.avgpool = nn.AvgPool2d(stride, stride)
        self.unfold = nn.Unfold(kernel_size, 1, (kernel_size - 1) // 2, stride)

    def forward(self, x):
        """
        Forward pass of the Involution module.

        Args:
            x (torch.Tensor): Input tensor with `c1` channels.

        Returns:
            out (torch.Tensor): Output tensor with `c1` channels after applying the involution operation.
        """
        # Warnings raised inside this block are silenced.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            # When striding, generate kernels from a pooled copy of the input so the kernel map
            # has the same spatial size as the strided patches produced by unfold.
            kernel_input = self.avgpool(x) if self.stride > 1 else x
            weight = self.conv2(self.conv1(kernel_input))
            b, _, h, w = weight.shape
            k2 = self.kernel_size ** 2
            # (b, groups, 1, k*k, h, w): one kernel per position, broadcast over the group's channels.
            weight = weight.view(b, self.groups, k2, h, w).unsqueeze(2)
            # (b, groups, group_channels, k*k, h, w): the k*k neighbourhood of every position.
            patches = self.unfold(x).view(b, self.groups, self.group_channels, k2, h, w)
            # Weighted sum over the k*k neighbourhood, then merge groups back into channels.
            out = (weight * patches).sum(dim=3).view(b, self.c1, h, w)

        return out
60 changes: 60 additions & 0 deletions models/hub/yolov5s-cbam-involution.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license

# Parameters
nc: 10 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
- [2.9434,4.0435, 3.8626,8.5592, 6.8534, 5.9391] # P2/4
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3, [1024]],
[-1, 3, CBAM, [1024, 3]],
[-1, 1, SPPF, [1024, 5]], # 10
]

# YOLOv5 v6.0 head
head:
[[-1, 1, Involution, [1024, 1, 1]],
[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 15

[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [512, False]], # 19

[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 2], 1, Concat, [1]],
[-1, 3, C3, [256, False]], # 23 160*160 p2 head

[-1, 1, Conv, [256, 3, 2]],
[[-1, 19], 1, Concat, [1]],
[-1, 3, C3, [512, False]], # 26 80*80 p3 head

[-1, 1, Conv, [256, 3, 2]],
[[-1, 15], 1, Concat, [1]],
[-1, 3, C3, [256, False]], # 29 40*40 p4 head

[-1, 1, Conv, [512, 3, 2]],
[[-1, 11], 1, Concat, [1]],
[-1, 3, C3, [1024, False]], # 32 20*20 p5 head

[[23, 26, 29, 32], 1, Detect, [nc, anchors]], # Detect(P2, P3, P4, P5)
]
2 changes: 1 addition & 1 deletion models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in {
Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, CBAM, Involution}:
c1, c2 = ch[f], args[0]
if c2 != no: # if not output
c2 = make_divisible(c2 * gw, 8)
Expand Down
3 changes: 2 additions & 1 deletion utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,8 @@ def init_seeds(seed=0, deterministic=False):
torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
# torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
torch.use_deterministic_algorithms(True)
# nn.AdaptiveAvgPool2d has no deterministic backward implementation on CUDA, so deterministic algorithms are not enforced here
torch.use_deterministic_algorithms(False, warn_only=True)
torch.backends.cudnn.deterministic = True
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
os.environ['PYTHONHASHSEED'] = str(seed)
Expand Down