Skip to content

Commit

Permalink
Merge branch 'release/2.6' of https://github.com/PaddlePaddle/PaddleYOLO
Browse files Browse the repository at this point in the history
 into release/2.6
  • Loading branch information
nemonameless committed May 25, 2023
2 parents 43cf8c1 + cb3c659 commit bcc4404
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 4 deletions.
3 changes: 2 additions & 1 deletion ppdet/engine/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
}

TO_STATIC_SPEC = {
'yolov5_l_300e_coco': None,
'yolov3_darknet53_270e_coco': [{
'im_id': paddle.static.InputSpec(
name='im_id', shape=[-1, 1], dtype='float32'),
Expand Down Expand Up @@ -170,7 +171,7 @@ def _dump_infer_config(config, path, image_shape, model):
infer_arch = 'PPYOLOE'

if infer_arch in [
'YOLOX', 'YOLOF', 'PPYOLOE', 'YOLOv5', 'YOLOv6', 'YOLOv7', 'YOLOv8'
'YOLOX', 'YOLOF', 'PPYOLOE', 'YOLOv5', 'YOLOv6', 'YOLOv7', 'YOLOv8'
]:
infer_cfg['arch'] = infer_arch
infer_cfg['min_subgraph_size'] = TRT_MIN_SUBGRAPH[infer_arch]
Expand Down
5 changes: 4 additions & 1 deletion ppdet/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def __init__(self, cfg, mode='train'):
self.custom_white_list = self.cfg.get('custom_white_list', None)
self.custom_black_list = self.cfg.get('custom_black_list', None)

if self.cfg.architecture in ['RTMDet', 'YOLOv6'] and self.mode == 'train':
if self.cfg.architecture in ['RTMDet', 'YOLOv6'
] and self.mode == 'train':
raise NotImplementedError('{} training not supported yet.'.format(
self.cfg.architecture))
if 'slim' in cfg and cfg['slim_type'] == 'PTQ':
Expand Down Expand Up @@ -316,6 +317,8 @@ def train(self, validate=False):
model = self.model
if self.cfg.get('to_static', False):
model = apply_to_static(self.cfg, model)
if self.cfg.architecture == 'YOLOv5':
model.yolo_head.loss.to_static = True
sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
(self.cfg.use_gpu or self.cfg.use_mlu) and self._nranks > 1)
if sync_bn:
Expand Down
117 changes: 115 additions & 2 deletions ppdet/modeling/losses/yolov5_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(self,
],
dtype=np.float32) * bias # offsets
self.anchor_t = anchor_t
self.to_static = False

def build_targets(self, outputs, targets, anchors):
if 0:
Expand Down Expand Up @@ -206,8 +207,21 @@ def yolov5_loss(self, pi, t_cls, t_box, t_indices, t_anchor, balance):

# Classification
if self.num_classes > 1: # cls loss (only if multiple classes)
# t = paddle.full_like(ps[:, 5:], self.cls_neg_label)
# t[range(n), t_cls] = self.cls_pos_label
# loss_cls = self.BCEcls(ps[:, 5:], t)

t = paddle.full_like(ps[:, 5:], self.cls_neg_label)
t[range(n), t_cls] = self.cls_pos_label
if not self.to_static:
t = paddle.put_along_axis(
t,
t_cls.unsqueeze(-1),
values=self.cls_pos_label,
axis=1)
else:
for i in range(n):
t[i, t_cls[i]] = self.cls_pos_label

loss_cls = self.BCEcls(ps[:, 5:], t)

obji = self.BCEobj(pi[:, :, :, :, 4], tobj) # [bs, 3, h, w]
Expand All @@ -221,7 +235,12 @@ def yolov5_loss(self, pi, t_cls, t_box, t_indices, t_anchor, balance):

def forward(self, inputs, targets, anchors):
yolo_losses = dict()
tcls, tbox, indices, anch = self.build_targets(inputs, targets, anchors)
if not self.to_static:
tcls, tbox, indices, anch = self.build_targets(inputs, targets,
anchors)
else:
tcls, tbox, indices, anch = self.build_targets_paddle(
inputs, targets, anchors)

for i, (p_det, balance) in enumerate(zip(inputs, self.balance)):
t_cls = tcls[i]
Expand Down Expand Up @@ -250,3 +269,97 @@ def forward(self, inputs, targets, anchors):
loss += yolo_losses[k]
yolo_losses['loss'] = loss
return yolo_losses

def build_targets_paddle(self, outputs, targets, anchors):
# targets['gt_class'] [bs, max_gt_nums, 1]
# targets['gt_bbox'] [bs, max_gt_nums, 4]
# targets['pad_gt_mask'] [bs, max_gt_nums, 1]
gt_nums = [len(bbox) for bbox in targets['gt_bbox']]
nt = int(sum(gt_nums))
anchors = anchors
na = anchors.shape[1] # not len(anchors)
tcls, tbox, indices, anch = [], [], [], []

gain = paddle.ones(
[7], dtype=paddle.float32) # normalized to gridspace gain
ai = paddle.tile(
paddle.arange(
na, dtype=paddle.float32).reshape([na, 1]), [1, nt])

batch_size = outputs[0].shape[0]
gt_labels = []
for i, (
gt_num, gt_bboxs, gt_classes
) in enumerate(zip(gt_nums, targets['gt_bbox'], targets['gt_class'])):
if gt_num == 0:
continue
gt_bbox = gt_bboxs[:gt_num].astype('float32')
gt_class = (gt_classes[:gt_num] * 1.0).astype('float32')
img_idx = paddle.repeat_interleave(
paddle.to_tensor([i]), gt_num,
axis=0)[None, :].astype('float32').T

gt_labels.append(
paddle.concat(
(img_idx, gt_class, gt_bbox), axis=-1))

if (len(gt_labels)):
gt_labels = paddle.concat(gt_labels)
else:
gt_labels = paddle.zeros([0, 6], dtype=paddle.float32)

targets_labels = paddle.concat((paddle.tile(
paddle.unsqueeze(gt_labels, 0), [na, 1, 1]), ai[:, :, None]), 2)
g = self.bias # 0.5

for i in range(len(anchors)):
anchor = anchors[i] / self.downsample_ratios[i]
gain[2:6] = paddle.to_tensor(
outputs[i].shape,
dtype=paddle.float32)[[3, 2, 3, 2]] # xyxy gain

# Match targets_labels to
t = targets_labels * gain
if nt:
# Matches
r = t[:, :, 4:6] / anchor[:, None]
j = paddle.maximum(r, 1 / r).max(2) < self.anchor_t
t = paddle.flatten(t, 0, 1)
j = paddle.flatten(j.astype(paddle.int32), 0,
1).astype(paddle.bool)
t = t[j] # filter

# Offsets
gxy = t[:, 2:4] # grid xy
gxi = gain[[2, 3]] - gxy # inverse
j, k = ((gxy % 1 < g) & (gxy > 1)).T.astype(paddle.int64)
l, m = ((gxi % 1 < g) & (gxi > 1)).T.astype(paddle.int64)
j = paddle.flatten(
paddle.stack((paddle.ones_like(j), j, k, l, m)), 0,
1).astype(paddle.bool)
t = paddle.flatten(paddle.tile(t, [5, 1, 1]), 0, 1)
t = t[j]
offsets = paddle.zeros_like(gxy)[None, :] + paddle.to_tensor(
self.off)[:, None]
offsets = paddle.flatten(offsets, 0, 1)[j]
else:
t = targets_labels[0]
offsets = 0

# Define
b, c = t[:, :2].astype(paddle.int64).T # image, class
gxy = t[:, 2:4] # grid xy
gwh = t[:, 4:6] # grid wh
gij = (gxy - offsets).astype(paddle.int64)
gi, gj = gij.T # grid xy indices

# Append
a = t[:, 6].astype(paddle.int64) # anchor indices
gj, gi = gj.clip(0, gain[3] - 1), gi.clip(0, gain[2] - 1)
indices.append(
(b, a, gj.astype(paddle.int64), gi.astype(paddle.int64)))
tbox.append(
paddle.concat((gxy - gij, gwh), 1).astype(paddle.float32))
anch.append(anchor[a])
tcls.append(c)
return tcls, tbox, indices, anch

0 comments on commit bcc4404

Please sign in to comment.