Skip to content

Commit

Permalink
some fix for yolo5
Browse files Browse the repository at this point in the history
  • Loading branch information
david8862 committed Jan 29, 2021
1 parent b8babc1 commit 5e7dcc1
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 22 deletions.
28 changes: 20 additions & 8 deletions yolo5/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,20 @@ def yolo5_loss(args, anchors, num_classes, ignore_thresh=.5, label_smoothing=0,

# gains for box, class and confidence loss
# from https://github.com/ultralytics/yolov5/blob/master/data/hyp.scratch.yaml
box_loss_gain = 0.05
class_loss_gain = 0.5
#box_loss_gain = 0.05
#class_loss_gain = 0.5
#confidence_loss_gain = 1.0

box_loss_gain = 1.0
class_loss_gain = 1.0
confidence_loss_gain = 1.0

# balance weights for the confidence (objectness) loss
# on the different prediction heads (x/32, x/16, x/8);
# note the order here is reversed relative to the ultralytics PyTorch version,
# see https://github.com/ultralytics/yolov5/blob/master/utils/loss.py#L109
confidence_balance_weights = [0.4, 1.0, 4.0]
#confidence_balance_weights = [0.4, 1.0, 4.0]
confidence_balance_weights = [1.0, 1.0, 1.0]

if num_layers == 3:
anchor_mask = [[6,7,8], [3,4,5], [0,1,2]]
Expand Down Expand Up @@ -312,7 +317,7 @@ def yolo5_loss(args, anchors, num_classes, ignore_thresh=.5, label_smoothing=0,
raw_true_xy = y_true[i][..., :2]*grid_shapes[i][::-1] - grid
raw_true_wh = K.log(y_true[i][..., 2:4] / anchors[anchor_mask[i]] * input_shape[::-1])
raw_true_wh = K.switch(object_mask, raw_true_wh, K.zeros_like(raw_true_wh)) # avoid log(0)=-inf
#box_loss_scale = 2 - y_true[i][...,2:3]*y_true[i][...,3:4]
box_loss_scale = 2 - y_true[i][...,2:3]*y_true[i][...,3:4]

# Find ignore mask, iterate over each of batch.
#ignore_mask = tf.TensorArray(K.dtype(y_true[0]), size=1, dynamic_size=True)
Expand All @@ -331,14 +336,14 @@ def yolo5_loss(args, anchors, num_classes, ignore_thresh=.5, label_smoothing=0,
# Calculate GIoU loss as location loss
raw_true_box = y_true[i][...,0:4]
giou = box_giou(raw_true_box, pred_box)
giou_loss = object_mask * (1 - giou)
giou_loss = object_mask * box_loss_scale * (1 - giou)
location_loss = giou_loss
iou = giou
elif use_diou_loss:
# Calculate DIoU loss as location loss
raw_true_box = y_true[i][...,0:4]
diou = box_diou(raw_true_box, pred_box)
diou_loss = object_mask * (1 - diou)
diou_loss = object_mask * box_loss_scale * (1 - diou)
location_loss = diou_loss
iou = diou
else:
Expand All @@ -351,10 +356,10 @@ def yolo5_loss(args, anchors, num_classes, ignore_thresh=.5, label_smoothing=0,
#wh_loss = K.sum(wh_loss) / batch_size_f
#location_loss = xy_loss + wh_loss

# use box iou for positive sample as objectness ground truth,
# use the box IoU of each positive sample as the objectness ground truth
# (its gradient needs to be detached) when calculating the confidence loss
# from https://github.com/ultralytics/yolov5/blob/master/utils/loss.py#L127
true_objectness_probs = K.maximum(iou, 0)
true_objectness_probs = tf.stop_gradient(K.maximum(iou, 0))

if use_focal_obj_loss:
# Focal loss for objectness confidence
Expand Down Expand Up @@ -383,6 +388,13 @@ def yolo5_loss(args, anchors, num_classes, ignore_thresh=.5, label_smoothing=0,
class_loss = class_loss_gain * K.sum(class_loss) / batch_size_f
location_loss = box_loss_gain * K.sum(location_loss) / batch_size_f

#object_number = K.sum(object_mask)
#divided_factor = K.switch(object_number > 1, object_number, K.constant(1, K.dtype(object_number)))

#confidence_loss = confidence_loss_gain * K.mean(confidence_loss)
#class_loss = class_loss_gain * K.sum(class_loss) / divided_factor
#location_loss = box_loss_gain * K.sum(location_loss) / divided_factor

loss += location_loss + confidence_loss + class_loss
total_location_loss += location_loss
total_confidence_loss += confidence_loss
Expand Down
18 changes: 11 additions & 7 deletions yolo5/models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,22 +171,24 @@ def bottleneck_csp_c3_lite_block(x, num_filters, num_blocks, depth_multiple, wid
return DarknetConv2D_BN_Swish(num_filters, (1,1))(x)


def make_yolo5_spp_neck(x, num_filters):
def yolo5_spp_neck(x, num_filters):
'''Conv2D_BN_Swish layer followed by a SPP_Conv block'''
x = DarknetConv2D_BN_Swish(num_filters//2, (1,1))(x)
x = Spp_Conv2D_BN_Swish(x, num_filters)

return x


def yolo5_predictions(feature_maps, feature_channel_nums, num_anchors, num_classes, depth_multiple, width_multiple):
def yolo5_predictions(feature_maps, feature_channel_nums, num_anchors, num_classes, depth_multiple, width_multiple, with_spp=True):
f1, f2, f3 = feature_maps
f1_channel_num, f2_channel_num, f3_channel_num = feature_channel_nums

# SPP & BottleneckCSP block; in the ultralytics PyTorch version
# these are defined in the backbone
x1 = make_yolo5_spp_neck(f1, f1_channel_num)
x1 = bottleneck_csp_block(x1, f1_channel_num, 3, depth_multiple, width_multiple, shortcut=False)
if with_spp:
f1 = yolo5_spp_neck(f1, f1_channel_num)

x1 = bottleneck_csp_block(f1, f1_channel_num, 3, depth_multiple, width_multiple, shortcut=False)

#feature map 1 head (19x19 for 608 input)
x1 = DarknetConv2D_BN_Swish(f2_channel_num, (1,1))(x1)
Expand Down Expand Up @@ -234,14 +236,16 @@ def yolo5_predictions(feature_maps, feature_channel_nums, num_anchors, num_class
return y1, y2, y3


def yolo5lite_predictions(feature_maps, feature_channel_nums, num_anchors, num_classes, depth_multiple, width_multiple):
def yolo5lite_predictions(feature_maps, feature_channel_nums, num_anchors, num_classes, depth_multiple, width_multiple, with_spp=True):
f1, f2, f3 = feature_maps
f1_channel_num, f2_channel_num, f3_channel_num = feature_channel_nums

# SPP & BottleneckCSP block; in the ultralytics PyTorch version
# these are defined in the backbone
x1 = make_yolo5_spp_neck(f1, f1_channel_num)
x1 = bottleneck_csp_lite_block(x1, f1_channel_num, 3, depth_multiple, width_multiple, shortcut=False, block_id_str='pred_1')
if with_spp:
f1 = yolo5_spp_neck(f1, f1_channel_num)

x1 = bottleneck_csp_lite_block(f1, f1_channel_num, 3, depth_multiple, width_multiple, shortcut=False, block_id_str='pred_1')

#feature map 1 head (19x19 for 608 input)
x1 = DarknetConv2D_BN_Swish(f2_channel_num, (1,1))(x1)
Expand Down
21 changes: 14 additions & 7 deletions yolo5/models/yolo5_mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tensorflow.keras.models import Model
from tensorflow.keras.applications.mobilenet import MobileNet

from yolo5.models.layers import yolo5_predictions, yolo5lite_predictions
from yolo5.models.layers import yolo5_predictions, yolo5lite_predictions, yolo5_spp_neck


def yolo5_mobilenet_body(inputs, num_anchors, num_classes, alpha=1.0):
Expand All @@ -26,18 +26,22 @@ def yolo5_mobilenet_body(inputs, num_anchors, num_classes, alpha=1.0):
# f3: 52 x 52 x (256*alpha) for 416 input
f3 = mobilenet.get_layer('conv_pw_5_relu').output

# use yolo5_small depth_multiple, and alpha as width_multiple
# add SPP neck with original channel number
f1 = yolo5_spp_neck(f1, int(1024*alpha))

# use yolo5_small depth_multiple and width_multiple for head
depth_multiple = 0.33
width_multiple = alpha
width_multiple = 0.5

f1_channel_num = int(1024*width_multiple)
f2_channel_num = int(512*width_multiple)
f3_channel_num = int(256*width_multiple)

y1, y2, y3 = yolo5_predictions((f1, f2, f3), (f1_channel_num, f2_channel_num, f3_channel_num), num_anchors, num_classes, depth_multiple, width_multiple)
y1, y2, y3 = yolo5_predictions((f1, f2, f3), (f1_channel_num, f2_channel_num, f3_channel_num), num_anchors, num_classes, depth_multiple, width_multiple, with_spp=False)

return Model(inputs, [y1, y2, y3])


def yolo5lite_mobilenet_body(inputs, num_anchors, num_classes, alpha=1.0):
"""Create YOLO_V5 Lite MobileNet model CNN body in Keras."""
mobilenet = MobileNet(input_tensor=inputs, weights='imagenet', include_top=False, alpha=alpha)
Expand All @@ -55,15 +59,18 @@ def yolo5lite_mobilenet_body(inputs, num_anchors, num_classes, alpha=1.0):
# f3: 52 x 52 x (256*alpha) for 416 input
f3 = mobilenet.get_layer('conv_pw_5_relu').output

# use yolo5_small depth_multiple, and alpha as width_multiple
# add SPP neck with original channel number
f1 = yolo5_spp_neck(f1, int(1024*alpha))

# use yolo5_small depth_multiple and width_multiple for head
depth_multiple = 0.33
width_multiple = alpha
width_multiple = 0.5

f1_channel_num = int(1024*width_multiple)
f2_channel_num = int(512*width_multiple)
f3_channel_num = int(256*width_multiple)

y1, y2, y3 = yolo5lite_predictions((f1, f2, f3), (f1_channel_num, f2_channel_num, f3_channel_num), num_anchors, num_classes, depth_multiple, width_multiple)
y1, y2, y3 = yolo5lite_predictions((f1, f2, f3), (f1_channel_num, f2_channel_num, f3_channel_num), num_anchors, num_classes, depth_multiple, width_multiple, with_spp=False)

return Model(inputs, [y1, y2, y3])

0 comments on commit 5e7dcc1

Please sign in to comment.