-
Notifications
You must be signed in to change notification settings - Fork 7k
/
yolo_networks.py
2025 lines (1734 loc) · 96.4 KB
/
yolo_networks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import io
import re
from collections import OrderedDict
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from warnings import warn
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from ...ops import box_convert
from ..yolo import (
Conv,
CSPSPP,
CSPStage,
ELANStage,
FastSPP,
MaxPool,
RouteLayer,
ShortcutLayer,
YOLOV4Backbone,
YOLOV4TinyBackbone,
YOLOV5Backbone,
YOLOV7Backbone,
)
from .anchor_utils import global_xy
from .target_matching import HighestIoUMatching, IoUThresholdMatching, PRIOR_SHAPES, SimOTAMatching, SizeRatioMatching
from .yolo_loss import YOLOLoss
# Type aliases used in the signatures of the detection layers and networks in this module. The notes below about
# tuples indicate that these annotations are consumed by TorchScript.
DARKNET_CONFIG = Dict[str, Any]  # presumably one parsed section of a Darknet config; not used in this chunk
CREATE_LAYER_OUTPUT = Tuple[nn.Module, int]  # layer, num_outputs
PRED = Dict[str, Tensor]  # predictions for one image, e.g. keys "boxes", "confidences", "classprobs"
PREDS = List[PRED]  # TorchScript doesn't allow a tuple
TARGET = Dict[str, Tensor]  # training targets for one image, read with keys "boxes" and "labels"
TARGETS = List[TARGET]  # TorchScript doesn't allow a tuple
NETWORK_OUTPUT = Tuple[List[Tensor], List[Tensor], List[int]]  # detections, losses, hits
class DetectionLayer(nn.Module):
    """A YOLO detection layer.

    A YOLO model has usually 1 - 3 detection layers at different resolutions. The loss is summed from all of them.

    Args:
        num_classes: Number of different classes that this layer predicts.
        prior_shapes: A list of prior box dimensions for this layer, used for scaling the predicted dimensions. The list
            should contain [width, height] pairs in the network input resolution.
        matching_func: The matching algorithm to be used for assigning targets to anchors.
        loss_func: ``YOLOLoss`` object for calculating the losses.
        xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps
            to produce coordinate values close to one.
        input_is_normalized: The input is normalized by logistic activation in the previous layer. In this case the
            detection layer will not take the sigmoid of the coordinate and probability predictions, and the width and
            height are scaled up so that the maximum value is four times the anchor dimension. This is used by the
            Darknet configurations of Scaled-YOLOv4.
    """

    def __init__(
        self,
        num_classes: int,
        prior_shapes: PRIOR_SHAPES,
        matching_func: Callable,
        loss_func: YOLOLoss,
        xy_scale: float = 1.0,
        input_is_normalized: bool = False,
    ) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.prior_shapes = prior_shapes
        self.matching_func = matching_func
        self.loss_func = loss_func
        self.xy_scale = xy_scale
        self.input_is_normalized = input_is_normalized

    def forward(self, x: Tensor, image_size: Tensor) -> Tuple[Tensor, PREDS]:
        """Runs a forward pass through this YOLO detection layer.

        Maps cell-local coordinates to global coordinates in the image space, scales the bounding boxes with the
        anchors, converts the center coordinates to corner coordinates, and maps probabilities to the `]0, 1[` range
        using sigmoid.

        If targets are given, computes also losses from the predictions and the targets. This layer is responsible only
        for the targets that best match one of the anchors assigned to this layer. Training losses will be saved to the
        ``losses`` attribute. ``hits`` attribute will be set to the number of targets that this layer was responsible
        for. ``losses`` is a tensor of three elements: the overlap, confidence, and classification loss.

        Args:
            x: The output from the previous layer. The size of this tensor has to be
                ``[batch_size, anchors_per_cell * (num_classes + 5), height, width]``.
            image_size: Image width and height in a vector (defines the scale of the predicted and target coordinates).

        Returns:
            The layer output, with normalized probabilities, in a tensor sized
            ``[batch_size, anchors_per_cell * height * width, num_classes + 5]`` and a list of dictionaries, containing
            the same predictions, but with unnormalized probabilities (for loss calculation).

        Raises:
            ValueError: If the number of input channels is not a multiple of ``num_classes + 5``, or if the number of
                boxes predicted per spatial location differs from the number of prior shapes of this layer.
        """
        batch_size, num_features, height, width = x.shape
        num_attrs = self.num_classes + 5
        # Without this check a channel count such as 3 * num_attrs + 1 would pass the prior count check below and then
        # fail inside view() with an unhelpful shape error.
        if num_features % num_attrs != 0:
            raise ValueError(
                "The number of input channels ({}) is not a multiple of the number of attributes per box ({}).".format(
                    num_features, num_attrs
                )
            )
        anchors_per_cell = num_features // num_attrs
        if anchors_per_cell != len(self.prior_shapes):
            raise ValueError(
                "The model predicts {} bounding boxes per spatial location, but {} prior box dimensions are defined "
                "for this layer.".format(anchors_per_cell, len(self.prior_shapes))
            )

        # Reshape the output to have the bounding box attributes of each grid cell on its own row.
        x = x.permute(0, 2, 3, 1)  # [batch_size, height, width, anchors_per_cell * num_attrs]
        x = x.view(batch_size, height, width, anchors_per_cell, num_attrs)

        # Take the sigmoid of the bounding box coordinates, confidence score, and class probabilities, unless the input
        # is normalized by the previous layer activation. Confidence and class losses use the unnormalized values if
        # possible.
        norm_x = x if self.input_is_normalized else torch.sigmoid(x)
        xy = norm_x[..., :2]
        wh = x[..., 2:4]
        confidence = x[..., 4]
        classprob = x[..., 5:]
        norm_confidence = norm_x[..., 4]
        norm_classprob = norm_x[..., 5:]

        # Eliminate grid sensitivity. The previous layer should output extremely high values for the sigmoid to produce
        # x/y coordinates close to one. YOLOv4 solves this by scaling the x/y coordinates.
        xy = xy * self.xy_scale - 0.5 * (self.xy_scale - 1)

        image_xy = global_xy(xy, image_size)
        prior_shapes = torch.tensor(self.prior_shapes, dtype=wh.dtype, device=wh.device)
        if self.input_is_normalized:
            # wh is already in [0, 1], so the scale factor is at most 4 times the prior shape.
            image_wh = 4 * torch.square(wh) * prior_shapes
        else:
            image_wh = torch.exp(wh) * prior_shapes
        box = torch.cat((image_xy, image_wh), -1)
        box = box_convert(box, in_fmt="cxcywh", out_fmt="xyxy")
        output = torch.cat((box, norm_confidence.unsqueeze(-1), norm_classprob), -1)
        output = output.reshape(batch_size, height * width * anchors_per_cell, num_attrs)

        # It's better to use binary_cross_entropy_with_logits() for loss computation, so we'll provide the unnormalized
        # confidence and classprob, when available.
        preds = [{"boxes": b, "confidences": c, "classprobs": p} for b, c, p in zip(box, confidence, classprob)]

        return output, preds

    def match_targets(
        self,
        preds: PREDS,
        return_preds: PREDS,
        targets: TARGETS,
        image_size: Tensor,
    ) -> Tuple[PRED, TARGET]:
        """Matches the predictions to targets.

        Args:
            preds: List of predictions for each image, as returned by the ``forward()`` method of this layer. These will
                be matched to the training targets.
            return_preds: List of predictions for each image. The matched predictions will be returned from this list.
                When calculating the auxiliary loss for deep supervision, predictions from a different layer are used
                for loss computation.
            targets: List of training targets for each image.
            image_size: Width and height in a vector that defines the scale of the target coordinates.

        Returns:
            Two dictionaries, the matched predictions and targets.

        Raises:
            ValueError: If ``preds``, ``return_preds``, and ``targets`` do not all have the same length.
        """
        batch_size = len(preds)
        if (len(targets) != batch_size) or (len(return_preds) != batch_size):
            raise ValueError("Different batch size for predictions and targets.")

        # Creating lists that are concatenated in the end will confuse TorchScript compilation. Instead, we'll create
        # tensors and concatenate new matches immediately.
        pred_boxes = torch.empty((0, 4), device=return_preds[0]["boxes"].device)
        pred_confidences = torch.empty(0, device=return_preds[0]["confidences"].device)
        pred_bg_confidences = torch.empty(0, device=return_preds[0]["confidences"].device)
        pred_classprobs = torch.empty((0, self.num_classes), device=return_preds[0]["classprobs"].device)
        target_boxes = torch.empty((0, 4), device=targets[0]["boxes"].device)
        target_labels = torch.empty(0, dtype=torch.int64, device=targets[0]["labels"].device)

        for image_preds, image_return_preds, image_targets in zip(preds, return_preds, targets):
            if image_targets["boxes"].shape[0] > 0:
                pred_selector, background_selector, target_selector = self.matching_func(
                    image_preds, image_targets, image_size
                )
                pred_boxes = torch.cat((pred_boxes, image_return_preds["boxes"][pred_selector]))
                pred_confidences = torch.cat((pred_confidences, image_return_preds["confidences"][pred_selector]))
                pred_bg_confidences = torch.cat(
                    (pred_bg_confidences, image_return_preds["confidences"][background_selector])
                )
                pred_classprobs = torch.cat((pred_classprobs, image_return_preds["classprobs"][pred_selector]))
                target_boxes = torch.cat((target_boxes, image_targets["boxes"][target_selector]))
                target_labels = torch.cat((target_labels, image_targets["labels"][target_selector]))
            else:
                # An image without targets contributes all of its confidences as background.
                pred_bg_confidences = torch.cat((pred_bg_confidences, image_return_preds["confidences"].flatten()))

        matched_preds = {
            "boxes": pred_boxes,
            "confidences": pred_confidences,
            "bg_confidences": pred_bg_confidences,
            "classprobs": pred_classprobs,
        }
        matched_targets = {
            "boxes": target_boxes,
            "labels": target_labels,
        }
        return matched_preds, matched_targets

    def calculate_losses(
        self,
        preds: PREDS,
        targets: TARGETS,
        image_size: Tensor,
        loss_preds: Optional[PREDS] = None,
    ) -> Tuple[Tensor, int]:
        """Matches the predictions to targets and computes the losses.

        Args:
            preds: List of predictions for each image, as returned by ``forward()``. These will be matched to the
                training targets and used to compute the losses (unless another set of predictions for loss computation
                is given in ``loss_preds``).
            targets: List of training targets for each image.
            image_size: Width and height in a vector that defines the scale of the target coordinates.
            loss_preds: List of predictions for each image. If given, these will be used for loss computation, instead
                of the same predictions that were used for matching. This is needed for deep supervision in YOLOv7.

        Returns:
            A vector of the overlap, confidence, and classification loss, normalized by batch size, and the number of
            targets that were matched to this layer.
        """
        if loss_preds is None:
            loss_preds = preds

        matched_preds, matched_targets = self.match_targets(preds, loss_preds, targets, image_size)

        losses = self.loss_func.elementwise_sums(matched_preds, matched_targets, self.input_is_normalized, image_size)
        losses = torch.stack((losses.overlap, losses.confidence, losses.classification)) / len(preds)

        hits = len(matched_targets["boxes"])

        return losses, hits
def create_detection_layer(
    prior_shapes: PRIOR_SHAPES,
    prior_shape_idxs: List[int],
    matching_algorithm: Optional[str] = None,
    matching_threshold: Optional[float] = None,
    spatial_range: float = 5.0,
    size_range: float = 4.0,
    ignore_bg_threshold: float = 0.7,
    overlap_func: str = "ciou",
    predict_overlap: Optional[float] = None,
    label_smoothing: Optional[float] = None,
    overlap_loss_multiplier: float = 5.0,
    confidence_loss_multiplier: float = 1.0,
    class_loss_multiplier: float = 1.0,
    **kwargs: Any,
) -> DetectionLayer:
    """Builds a ``DetectionLayer`` together with its target matching strategy and its loss function.

    Args:
        prior_shapes: Every prior box dimension of the model, as [width, height] pairs in the network input resolution.
            They scale the predicted dimensions and may also be used when matching targets to anchors.
        prior_shape_idxs: Indices into ``prior_shapes`` selecting the (usually 3) prior shapes of this layer.
        matching_algorithm: How targets are assigned to anchors: "simota" (the SimOTA rule from YOLOX), "size" (prior
            shapes whose width and height ratio to the target stays below a threshold), "iou" (every prior shape with
            a high enough IoU), or "maxiou" (only the prior shape with the highest IoU). ``None`` means "maxiou".
        matching_threshold: Threshold used by the "size" and "iou" algorithms. Required by both.
        spatial_range: "simota" only considers anchors inside an `N x N` grid cell area centered at the target, where
            `N` is this value.
        size_range: "simota" only considers anchors whose dimensions are within a factor of `N` (between `1/N` and `N`
            times) of the target dimensions, where `N` is this value.
        ignore_bg_threshold: A predictor that is not responsible for any target is left out of the confidence loss when
            its anchor's IoU with some target exceeds this threshold.
        overlap_func: IoU variant used between two sets of boxes: "iou", "giou", "diou", or "ciou".
        predict_overlap: Balance between binary confidence targets (0.0: target confidence is 1 whenever there's an
            object) and predicting the overlap (1.0: target confidence is the output of ``overlap_func``).
        label_smoothing: The epsilon (weight) of class label smoothing. 0.0 keeps binary targets, 1.0 turns every target
            probability into 0.5.
        overlap_loss_multiplier: Scale factor of the overlap loss.
        confidence_loss_multiplier: Scale factor of the confidence loss.
        class_loss_multiplier: Scale factor of the classification loss.
        num_classes: Number of different classes that this layer predicts.
        xy_scale: Scales the box coordinates to eliminate "grid sensitivity". A value > 1.0 makes coordinates close to
            one reachable.
        input_is_normalized: Set when the previous layer already applied logistic activation. The detection layer then
            skips the sigmoid on coordinates and probabilities, and scales width and height so that they reach at most
            four times the anchor dimension. The Scaled-YOLOv4 Darknet configurations use this.

    Returns:
        The configured ``DetectionLayer``. Any remaining ``kwargs`` are forwarded to its constructor.

    Raises:
        ValueError: If ``matching_algorithm`` is unknown, or "size"/"iou" matching is requested without a
            ``matching_threshold``.
    """
    matching_func: Callable
    if matching_algorithm is None or matching_algorithm == "maxiou":
        matching_func = HighestIoUMatching(prior_shapes, prior_shape_idxs, ignore_bg_threshold)
    elif matching_algorithm == "simota":
        # The SimOTA matcher gets its own YOLOLoss, built without predict_overlap and label_smoothing (both None).
        matching_loss = YOLOLoss(
            overlap_func, None, None, overlap_loss_multiplier, confidence_loss_multiplier, class_loss_multiplier
        )
        matching_func = SimOTAMatching(prior_shapes, prior_shape_idxs, matching_loss, spatial_range, size_range)
    elif matching_algorithm == "size":
        if matching_threshold is None:
            raise ValueError("matching_threshold is required with size ratio matching.")
        matching_func = SizeRatioMatching(prior_shapes, prior_shape_idxs, matching_threshold, ignore_bg_threshold)
    elif matching_algorithm == "iou":
        if matching_threshold is None:
            raise ValueError("matching_threshold is required with IoU threshold matching.")
        matching_func = IoUThresholdMatching(prior_shapes, prior_shape_idxs, matching_threshold, ignore_bg_threshold)
    else:
        raise ValueError(f"Matching algorithm `{matching_algorithm}´ is unknown.")

    # The loss that the layer itself uses for training, with the full set of options.
    layer_loss = YOLOLoss(
        overlap_func,
        predict_overlap,
        label_smoothing,
        overlap_loss_multiplier,
        confidence_loss_multiplier,
        class_loss_multiplier,
    )

    selected_shapes = [prior_shapes[idx] for idx in prior_shape_idxs]
    return DetectionLayer(prior_shapes=selected_shapes, matching_func=matching_func, loss_func=layer_loss, **kwargs)
class DetectionStage(nn.Module):
    """Wraps a single detection layer so that the networks can invoke it like a stage.

    A plain function would be simpler, but TorchScript only permits specific types as function arguments, and modules
    are not among them.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__()
        # All keyword arguments are passed through to create_detection_layer().
        self.detection_layer = create_detection_layer(**kwargs)

    def forward(
        self,
        layer_input: Tensor,
        targets: Optional[TARGETS],
        image_size: Tensor,
        detections: List[Tensor],
        losses: List[Tensor],
        hits: List[int],
    ) -> None:
        """Runs the detection layer and appends its output to ``detections``.

        When ``targets`` is provided, the losses and the number of matched targets are appended as well.

        Args:
            layer_input: Tensor that is fed to the detection layer.
            targets: Training targets, one dictionary per image, or ``None`` for inference.
            image_size: Width and height in a vector; the scale of the target coordinates.
            detections: Receives the detection layer output.
            losses: Receives the loss vector of this layer when ``targets`` is given.
            hits: Receives the number of targets matched to this layer when ``targets`` is given.
        """
        layer_output, layer_preds = self.detection_layer(layer_input, image_size)
        detections.append(layer_output)

        if targets is not None:
            stage_losses, stage_hits = self.detection_layer.calculate_losses(layer_preds, targets, image_size)
            losses.append(stage_losses)
            hits.append(stage_hits)
class DetectionStageWithAux(nn.Module):
    """A lead detection layer paired with an auxiliary detection layer.

    Args:
        spatial_range: For "simota" matching, the lead head only considers anchors inside an `N x N` grid cell area
            centered at the target, where `N` is this value.
        aux_spatial_range: The same restriction for the auxiliary head.
        aux_weight: Factor that the auxiliary head's losses are multiplied by.
        kwargs: Remaining ``create_detection_layer()`` arguments, shared by both layers.
    """

    def __init__(
        self, spatial_range: float = 5.0, aux_spatial_range: float = 3.0, aux_weight: float = 0.25, **kwargs: Any
    ) -> None:
        super().__init__()
        self.detection_layer = create_detection_layer(spatial_range=spatial_range, **kwargs)
        self.aux_detection_layer = create_detection_layer(spatial_range=aux_spatial_range, **kwargs)
        self.aux_weight = aux_weight

    def forward(
        self,
        layer_input: Tensor,
        aux_input: Tensor,
        targets: Optional[TARGETS],
        image_size: Tensor,
        detections: List[Tensor],
        losses: List[Tensor],
        hits: List[int],
    ) -> None:
        """Runs the lead detection layer and appends its output to ``detections``.

        When ``targets`` is provided, the lead layer's losses and hits are appended. The auxiliary layer is then run
        and its weighted losses and hits are appended too. Only the lead output ever goes into ``detections``.

        Args:
            layer_input: Tensor fed to the lead detection layer.
            aux_input: Tensor fed to the auxiliary detection layer (only used when ``targets`` is given).
            targets: Training targets, one dictionary per image, or ``None`` for inference.
            image_size: Width and height in a vector; the scale of the target coordinates.
            detections: Receives the lead detection layer output.
            losses: Receives the lead and the weighted auxiliary loss vectors when ``targets`` is given.
            hits: Receives the lead and the auxiliary match counts when ``targets`` is given.
        """
        lead_output, lead_preds = self.detection_layer(layer_input, image_size)
        detections.append(lead_output)

        if targets is not None:
            # Lead head: matching and loss both use the lead predictions.
            lead_losses, lead_hits = self.detection_layer.calculate_losses(lead_preds, targets, image_size)
            losses.append(lead_losses)
            hits.append(lead_hits)

            # Auxiliary head: targets are matched against the lead predictions (with the auxiliary layer's matching
            # rule), but the losses are computed from the auxiliary predictions.
            _, aux_preds = self.aux_detection_layer(aux_input, image_size)
            aux_losses, aux_hits = self.aux_detection_layer.calculate_losses(
                lead_preds, targets, image_size, loss_preds=aux_preds
            )
            losses.append(aux_losses * self.aux_weight)
            hits.append(aux_hits)
@torch.jit.script
def get_image_size(images: Tensor) -> Tensor:
    """Returns the width and height of an image batch as a tensor.

    This function is compiled with ``@torch.jit.script`` so that ONNX export keeps the input size dynamic. The
    tracing-based exporter would otherwise lose track of values such as ``images.shape[2]``. It would treat them as
    Python constants and bake a fixed input size into the model.

    Args:
        images: Image batch whose last two dimensions are height and width.

    Returns:
        A tensor ``[width, height]`` on the same device as ``images``.
    """
    width_height = [images.shape[3], images.shape[2]]
    return torch.tensor(width_height, device=images.device)
class YOLOV4TinyNetwork(nn.Module):
    """The "tiny" network architecture from YOLOv4.

    Args:
        num_classes: Number of different classes that this model predicts.
        backbone: A backbone network that returns the output from each stage.
        width: The number of channels in the narrowest convolutional layer. The wider convolutional layers will use a
            number of channels that is a multiple of this value.
        activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic",
            "linear", or "none".
        normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
        prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for
            matching the targets to the anchors. The list should contain [width, height] pairs in the network input
            resolution. There should be `3N` pairs, where `N` is the number of anchors per spatial location. They are
            assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning that
            you typically want to sort the shapes from the smallest to the largest.
        matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule
            from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given
            ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that
            gives the highest IoU, default).
        matching_threshold: Threshold for "size" and "iou" matching algorithms.
        spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell
            area centered at the target, where `N` is the value of this parameter.
        size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and
            no less than `1/N` times the target dimensions, where `N` is the value of this parameter.
        ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU
            with some target greater than this threshold, the predictor will not be taken into account when calculating
            the confidence loss.
        overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou",
            "giou", "diou", and "ciou".
        predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target
            confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of
            ``overlap_func``.
        label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary
            targets), and 1.0 means that the target probabilities are always 0.5.
        overlap_loss_multiplier: Overlap loss will be scaled by this value.
        confidence_loss_multiplier: Confidence loss will be scaled by this value.
        class_loss_multiplier: Classification loss will be scaled by this value.
        xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps
            to produce coordinate values close to one.

    Raises:
        ValueError: If the number of given prior shapes is not divisible by 3.
    """

    def __init__(
        self,
        num_classes: int,
        backbone: Optional[nn.Module] = None,
        width: int = 32,
        activation: Optional[str] = "leaky",
        normalization: Optional[str] = "batchnorm",
        prior_shapes: Optional[PRIOR_SHAPES] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()

        # By default use the prior shapes that have been learned from the COCO data.
        if prior_shapes is None:
            prior_shapes = [
                [12, 16],
                [19, 36],
                [40, 28],
                [36, 75],
                [76, 55],
                [72, 146],
                [142, 110],
                [192, 243],
                [459, 401],
            ]
            anchors_per_cell = 3
        else:
            anchors_per_cell, modulo = divmod(len(prior_shapes), 3)
            if modulo != 0:
                raise ValueError("The number of provided prior shapes needs to be divisible by 3.")
        num_outputs = (5 + num_classes) * anchors_per_cell

        def conv(in_channels: int, out_channels: int, kernel_size: int = 1) -> nn.Module:
            return Conv(in_channels, out_channels, kernel_size, stride=1, activation=activation, norm=normalization)

        def upsample(in_channels: int, out_channels: int) -> nn.Module:
            channels = conv(in_channels, out_channels)
            upsampling = nn.Upsample(scale_factor=2, mode="nearest")
            return nn.Sequential(OrderedDict([("channels", channels), ("upsample", upsampling)]))

        def outputs(in_channels: int) -> nn.Module:
            return nn.Conv2d(in_channels, num_outputs, kernel_size=1, stride=1, bias=True)

        def detect(prior_shape_idxs: Sequence[int]) -> DetectionStage:
            assert prior_shapes is not None
            return DetectionStage(
                prior_shapes=prior_shapes,
                prior_shape_idxs=list(prior_shape_idxs),
                num_classes=num_classes,
                input_is_normalized=False,
                **kwargs,
            )

        self.backbone = backbone or YOLOV4TinyBackbone(width=width, activation=activation, normalization=normalization)

        self.fpn5 = conv(width * 16, width * 8)
        self.out5 = nn.Sequential(
            OrderedDict(
                [
                    ("channels", conv(width * 8, width * 16)),
                    (f"outputs_{num_outputs}", outputs(width * 16)),
                ]
            )
        )
        self.upsample5 = upsample(width * 8, width * 4)

        self.fpn4 = conv(width * 12, width * 8, kernel_size=3)
        self.out4 = nn.Sequential(OrderedDict([(f"outputs_{num_outputs}", outputs(width * 8))]))
        self.upsample4 = upsample(width * 8, width * 2)

        self.fpn3 = conv(width * 6, width * 4, kernel_size=3)
        self.out3 = nn.Sequential(OrderedDict([(f"outputs_{num_outputs}", outputs(width * 4))]))

        # Each detection layer takes its own consecutive group of `anchors_per_cell` prior shapes, from the
        # high-resolution layer (first group) to the low-resolution layer (last group). The groups must match
        # `anchors_per_cell`, because the output convolutions predict that many boxes per spatial location.
        self.detect3 = detect(range(0, anchors_per_cell))
        self.detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2))
        self.detect5 = detect(range(anchors_per_cell * 2, anchors_per_cell * 3))

    def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT:
        """Runs the network and its detection layers on an image batch.

        Args:
            x: Image batch. Its last two dimensions (height and width) define the scale of the predicted and target
                coordinates.
            targets: Training targets, one dictionary per image. If given, the losses are calculated too.

        Returns:
            The outputs of the detection layers (ordered from the low-resolution to the high-resolution layer), the
            losses of each detection layer (empty without ``targets``), and the number of targets each detection layer
            was responsible for (empty without ``targets``).
        """
        detections: List[Tensor] = []  # Outputs from detection layers
        losses: List[Tensor] = []  # Losses from detection layers
        hits: List[int] = []  # Number of targets each detection layer was responsible for

        image_size = get_image_size(x)

        c3, c4, c5 = self.backbone(x)[-3:]

        p5 = self.fpn5(c5)
        x = torch.cat((self.upsample5(p5), c4), dim=1)
        p4 = self.fpn4(x)
        x = torch.cat((self.upsample4(p4), c3), dim=1)
        p3 = self.fpn3(x)

        self.detect5(self.out5(p5), targets, image_size, detections, losses, hits)
        self.detect4(self.out4(p4), targets, image_size, detections, losses, hits)
        self.detect3(self.out3(p3), targets, image_size, detections, losses, hits)
        return detections, losses, hits
class YOLOV4Network(nn.Module):
"""Network architecture that corresponds approximately to the Cross Stage Partial Network from YOLOv4.
Args:
num_classes: Number of different classes that this model predicts.
backbone: A backbone network that returns the output from each stage.
widths: Number of channels at each network stage.
activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic",
"linear", or "none".
normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for
matching the targets to the anchors. The list should contain [width, height] pairs in the network input
resolution. There should be `3N` pairs, where `N` is the number of anchors per spatial location. They are
assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning that
you typically want to sort the shapes from the smallest to the largest.
matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule
from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given
ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that
gives the highest IoU, default).
matching_threshold: Threshold for "size" and "iou" matching algorithms.
spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell
area centered at the target, where `N` is the value of this parameter.
size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and
no less than `1/N` times the target dimensions, where `N` is the value of this parameter.
ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU
with some target greater than this threshold, the predictor will not be taken into account when calculating
the confidence loss.
overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou",
"giou", "diou", and "ciou".
predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target
confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of
``overlap_func``.
label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary
targets), and 1.0 means that the target probabilities are always 0.5.
overlap_loss_multiplier: Overlap loss will be scaled by this value.
confidence_loss_multiplier: Confidence loss will be scaled by this value.
class_loss_multiplier: Classification loss will be scaled by this value.
xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps
to produce coordinate values close to one.
"""
def __init__(
    self,
    num_classes: int,
    backbone: Optional[nn.Module] = None,
    widths: Sequence[int] = (32, 64, 128, 256, 512, 1024),
    activation: Optional[str] = "silu",
    normalization: Optional[str] = "batchnorm",
    prior_shapes: Optional[PRIOR_SHAPES] = None,
    **kwargs: Any,
) -> None:
    """Builds the backbone, the SPP + FPN/PAN neck, and three detection layers.

    Args:
        num_classes: Number of different classes that this model predicts.
        backbone: A backbone network that returns the output from each stage. Only the last three outputs are used,
            and they must have ``widths[-3]``, ``widths[-2]``, and ``widths[-1]`` channels. If ``None``, a
            ``YOLOV4Backbone`` is created from ``widths``.
        widths: Number of channels at each network stage.
        activation: Which layer activation to use.
        normalization: Which layer normalization to use.
        prior_shapes: ``[width, height]`` pairs; there must be ``3N`` of them (``N >= 1``). Consecutive groups of
            ``N`` shapes are assigned to the detection layers from the highest to the lowest resolution. If ``None``,
            the nine default COCO shapes are used.
        **kwargs: Forwarded to every ``DetectionStage`` (see the class docstring for the accepted options).

    Raises:
        ValueError: If ``prior_shapes`` is empty or its length is not divisible by 3.
    """
    super().__init__()

    # By default use the prior shapes that have been learned from the COCO data.
    if prior_shapes is None:
        prior_shapes = [
            [12, 16],
            [19, 36],
            [40, 28],
            [36, 75],
            [76, 55],
            [72, 146],
            [142, 110],
            [192, 243],
            [459, 401],
        ]
        anchors_per_cell = 3
    else:
        anchors_per_cell, modulo = divmod(len(prior_shapes), 3)
        if modulo != 0:
            raise ValueError("The number of provided prior shapes needs to be divisible by 3.")
        if anchors_per_cell == 0:
            # An empty list passes the divisibility check, but it would build output convolutions with zero channels
            # and detection layers that have no prior shapes assigned.
            raise ValueError("At least one prior shape per detection layer (3 in total) is required.")

    # Output channels per detection layer: (5 + num_classes) values for every anchor of a grid cell (presumably
    # 4 box coordinates + 1 confidence + class scores, as interpreted by DetectionStage).
    num_outputs = (5 + num_classes) * anchors_per_cell

    def spp(in_channels: int, out_channels: int) -> nn.Module:
        return CSPSPP(in_channels, out_channels, activation=activation, norm=normalization)

    def conv(in_channels: int, out_channels: int) -> nn.Module:
        return Conv(in_channels, out_channels, kernel_size=1, stride=1, activation=activation, norm=normalization)

    def csp(in_channels: int, out_channels: int) -> nn.Module:
        return CSPStage(
            in_channels,
            out_channels,
            depth=2,
            shortcut=False,
            norm=normalization,
            activation=activation,
        )

    def out(in_channels: int) -> nn.Module:
        # The OrderedDict keys become submodule names and therefore state_dict keys, so they must stay stable.
        hidden = Conv(in_channels, in_channels, kernel_size=3, stride=1, activation=activation, norm=normalization)
        outputs = nn.Conv2d(in_channels, num_outputs, kernel_size=1)
        return nn.Sequential(OrderedDict([("conv", hidden), (f"outputs_{num_outputs}", outputs)]))

    def upsample(in_channels: int, out_channels: int) -> nn.Module:
        channels = conv(in_channels, out_channels)
        resize = nn.Upsample(scale_factor=2, mode="nearest")
        return nn.Sequential(OrderedDict([("channels", channels), ("upsample", resize)]))

    def downsample(in_channels: int, out_channels: int) -> nn.Module:
        return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization)

    def detect(prior_shape_idxs: Sequence[int]) -> DetectionStage:
        assert prior_shapes is not None  # Narrows the Optional type inside this closure.
        return DetectionStage(
            prior_shapes=prior_shapes,
            prior_shape_idxs=list(prior_shape_idxs),
            num_classes=num_classes,
            input_is_normalized=False,
            **kwargs,
        )

    if backbone is not None:
        self.backbone = backbone
    else:
        self.backbone = YOLOV4Backbone(widths=widths, activation=activation, normalization=normalization)

    w3 = widths[-3]
    w4 = widths[-2]
    w5 = widths[-1]

    self.spp = spp(w5, w5)

    # Top-down (FPN) path. Each csp() input is the concatenation of two w // 2 branches, which adds up to w only
    # when w is even.
    self.pre4 = conv(w4, w4 // 2)
    self.upsample5 = upsample(w5, w4 // 2)
    self.fpn4 = csp(w4, w4)

    self.pre3 = conv(w3, w3 // 2)
    self.upsample4 = upsample(w4, w3 // 2)
    self.fpn3 = csp(w3, w3)

    # Bottom-up (PAN) path: the downsampled map is concatenated with the same-resolution map from above.
    self.downsample3 = downsample(w3, w3)
    self.pan4 = csp(w3 + w4, w4)

    self.downsample4 = downsample(w4, w4)
    self.pan5 = csp(w4 + w5, w5)

    self.out3 = out(w3)
    self.out4 = out(w4)
    self.out5 = out(w5)

    # Layer k uses the k-th consecutive group of anchors_per_cell prior shapes.
    self.detect3 = detect(range(0, anchors_per_cell))
    self.detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2))
    self.detect5 = detect(range(anchors_per_cell * 2, anchors_per_cell * 3))
def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT:
    """Runs the backbone, the neck, and the three detection layers.

    Args:
        x: Batch of input images. Its size is read with ``get_image_size()`` and passed to every detection layer.
        targets: Passed unchanged to every detection layer (presumably the ground truth used for the losses).

    Returns:
        Three lists, which start empty and are handed to ``detect3``, ``detect4``, and ``detect5`` (in that order)
        to be filled in: the detection layer outputs, their losses, and the number of targets each detection layer
        was responsible for.
    """
    image_size = get_image_size(x)

    detections: List[Tensor] = []
    losses: List[Tensor] = []
    hits: List[int] = []

    # Only the three deepest backbone outputs feed the neck.
    feat3, feat4, feat5 = self.backbone(x)[-3:]
    c5 = self.spp(feat5)

    # Top-down path: upsample the deeper map and fuse it with the lateral projection of the shallower one.
    upsampled5 = self.upsample5(c5)
    lateral4 = self.pre4(feat4)
    p4 = self.fpn4(torch.cat((upsampled5, lateral4), dim=1))

    upsampled4 = self.upsample4(p4)
    lateral3 = self.pre3(feat3)
    n3 = self.fpn3(torch.cat((upsampled4, lateral3), dim=1))

    # Bottom-up path: downsample and fuse with the same-resolution map from the top-down path.
    down3 = self.downsample3(n3)
    n4 = self.pan4(torch.cat((down3, p4), dim=1))

    down4 = self.downsample4(n4)
    n5 = self.pan5(torch.cat((down4, c5), dim=1))

    self.detect3(self.out3(n3), targets, image_size, detections, losses, hits)
    self.detect4(self.out4(n4), targets, image_size, detections, losses, hits)
    self.detect5(self.out5(n5), targets, image_size, detections, losses, hits)

    return detections, losses, hits
class YOLOV4P6Network(nn.Module):
    """Network architecture that corresponds approximately to the variant of YOLOv4 with four detection layers.

    Args:
        num_classes: Number of different classes that this model predicts.
        backbone: A backbone network that returns the output from each stage. Only the last four outputs are used,
            and they must have ``widths[-4]``, ``widths[-3]``, ``widths[-2]``, and ``widths[-1]`` channels.
        widths: Number of channels at each network stage.
        activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic",
            "linear", or "none".
        normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
        prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for
            matching the targets to the anchors. The list should contain [width, height] pairs in the network input
            resolution. There should be `4N` pairs (`N >= 1`), where `N` is the number of anchors per spatial
            location. They are assigned to the layers from the lowest (high-resolution) to the highest
            (low-resolution) layer, meaning that you typically want to sort the shapes from the smallest to the
            largest.
        matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule
            from YOLOX), "size" (match the prior shapes whose width and height ratios relative to the target are
            below the given threshold), "iou" (match all prior shapes that give a high enough IoU), or "maxiou"
            (match the prior shape that gives the highest IoU, default).
        matching_threshold: Threshold for "size" and "iou" matching algorithms.
        spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell
            area centered at the target, where `N` is the value of this parameter.
        size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and
            no less than `1/N` times the target dimensions, where `N` is the value of this parameter.
        ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU
            with some target greater than this threshold, the predictor will not be taken into account when
            calculating the confidence loss.
        overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou",
            "giou", "diou", and "ciou".
        predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the
            target confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of
            ``overlap_func``.
        label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary
            targets), and 1.0 means that the target probabilities are always 0.5.
        overlap_loss_multiplier: Overlap loss will be scaled by this value.
        confidence_loss_multiplier: Confidence loss will be scaled by this value.
        class_loss_multiplier: Classification loss will be scaled by this value.
        xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps
            to produce coordinate values close to one.

    Raises:
        ValueError: If ``prior_shapes`` is given but empty, or its length is not divisible by 4.
    """

    def __init__(
        self,
        num_classes: int,
        backbone: Optional[nn.Module] = None,
        widths: Sequence[int] = (32, 64, 128, 256, 512, 1024, 1024),
        activation: Optional[str] = "silu",
        normalization: Optional[str] = "batchnorm",
        prior_shapes: Optional[PRIOR_SHAPES] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()

        # By default use the prior shapes that have been learned from the COCO data. With four shapes per layer,
        # [61, 45], [97, 189], and [324, 451] each appear as the last shape of one layer and the first of the next.
        if prior_shapes is None:
            prior_shapes = [
                [13, 17],
                [31, 25],
                [24, 51],
                [61, 45],
                [61, 45],
                [48, 102],
                [119, 96],
                [97, 189],
                [97, 189],
                [217, 184],
                [171, 384],
                [324, 451],
                [324, 451],
                [545, 357],
                [616, 618],
                [1024, 1024],
            ]
            anchors_per_cell = 4
        else:
            anchors_per_cell, modulo = divmod(len(prior_shapes), 4)
            if modulo != 0:
                raise ValueError("The number of provided prior shapes needs to be divisible by 4.")
            if anchors_per_cell == 0:
                # An empty list passes the divisibility check, but it would build output convolutions with zero
                # channels and detection layers that have no prior shapes assigned.
                raise ValueError("At least one prior shape per detection layer (4 in total) is required.")

        # Output channels per detection layer: (5 + num_classes) values for every anchor of a grid cell (presumably
        # 4 box coordinates + 1 confidence + class scores, as interpreted by DetectionStage).
        num_outputs = (5 + num_classes) * anchors_per_cell

        def spp(in_channels: int, out_channels: int) -> nn.Module:
            return CSPSPP(in_channels, out_channels, activation=activation, norm=normalization)

        def conv(in_channels: int, out_channels: int) -> nn.Module:
            return Conv(in_channels, out_channels, kernel_size=1, stride=1, activation=activation, norm=normalization)

        def csp(in_channels: int, out_channels: int) -> nn.Module:
            return CSPStage(
                in_channels,
                out_channels,
                depth=2,
                shortcut=False,
                norm=normalization,
                activation=activation,
            )

        def out(in_channels: int) -> nn.Module:
            # The OrderedDict keys become submodule names and therefore state_dict keys, so they must stay stable.
            hidden = Conv(in_channels, in_channels, kernel_size=3, stride=1, activation=activation, norm=normalization)
            outputs = nn.Conv2d(in_channels, num_outputs, kernel_size=1)
            return nn.Sequential(OrderedDict([("conv", hidden), (f"outputs_{num_outputs}", outputs)]))

        def upsample(in_channels: int, out_channels: int) -> nn.Module:
            channels = conv(in_channels, out_channels)
            resize = nn.Upsample(scale_factor=2, mode="nearest")
            return nn.Sequential(OrderedDict([("channels", channels), ("upsample", resize)]))

        def downsample(in_channels: int, out_channels: int) -> nn.Module:
            return Conv(in_channels, out_channels, kernel_size=3, stride=2, activation=activation, norm=normalization)

        def detect(prior_shape_idxs: Sequence[int]) -> DetectionStage:
            assert prior_shapes is not None  # Narrows the Optional type inside this closure.
            return DetectionStage(
                prior_shapes=prior_shapes,
                prior_shape_idxs=list(prior_shape_idxs),
                num_classes=num_classes,
                input_is_normalized=False,
                **kwargs,
            )

        if backbone is not None:
            self.backbone = backbone
        else:
            self.backbone = YOLOV4Backbone(
                widths=widths, depths=(1, 1, 3, 15, 15, 7, 7), activation=activation, normalization=normalization
            )

        w3 = widths[-4]
        w4 = widths[-3]
        w5 = widths[-2]
        w6 = widths[-1]

        self.spp = spp(w6, w6)

        # Top-down (FPN) path. Each csp() input is the concatenation of two w // 2 branches, which adds up to w
        # only when w is even.
        self.pre5 = conv(w5, w5 // 2)
        self.upsample6 = upsample(w6, w5 // 2)
        self.fpn5 = csp(w5, w5)

        self.pre4 = conv(w4, w4 // 2)
        self.upsample5 = upsample(w5, w4 // 2)
        self.fpn4 = csp(w4, w4)

        self.pre3 = conv(w3, w3 // 2)
        self.upsample4 = upsample(w4, w3 // 2)
        self.fpn3 = csp(w3, w3)

        # Bottom-up (PAN) path: the downsampled map is concatenated with the same-resolution map from above.
        self.downsample3 = downsample(w3, w3)
        self.pan4 = csp(w3 + w4, w4)

        self.downsample4 = downsample(w4, w4)
        self.pan5 = csp(w4 + w5, w5)

        self.downsample5 = downsample(w5, w5)
        self.pan6 = csp(w5 + w6, w6)

        self.out3 = out(w3)
        self.out4 = out(w4)
        self.out5 = out(w5)
        self.out6 = out(w6)

        # Layer k uses the k-th consecutive group of anchors_per_cell prior shapes.
        self.detect3 = detect(range(0, anchors_per_cell))
        self.detect4 = detect(range(anchors_per_cell, anchors_per_cell * 2))
        self.detect5 = detect(range(anchors_per_cell * 2, anchors_per_cell * 3))
        self.detect6 = detect(range(anchors_per_cell * 3, anchors_per_cell * 4))

    def forward(self, x: Tensor, targets: Optional[TARGETS] = None) -> NETWORK_OUTPUT:
        """Runs the backbone, the neck, and the four detection layers.

        Args:
            x: Batch of input images. Its size is read with ``get_image_size()`` and passed to every detection layer.
            targets: Passed unchanged to every detection layer (presumably the ground truth used for the losses).

        Returns:
            Three lists, which start empty and are handed to ``detect3`` ... ``detect6`` (in that order) to be
            filled in: the detection layer outputs, their losses, and the number of targets each detection layer
            was responsible for.
        """
        detections: List[Tensor] = []  # Outputs from detection layers
        losses: List[Tensor] = []  # Losses from detection layers
        hits: List[int] = []  # Number of targets each detection layer was responsible for

        image_size = get_image_size(x)

        # Only the four deepest backbone outputs feed the neck.
        c3, c4, c5, x = self.backbone(x)[-4:]
        c6 = self.spp(x)

        # Top-down path.
        x = torch.cat((self.upsample6(c6), self.pre5(c5)), dim=1)
        p5 = self.fpn5(x)
        x = torch.cat((self.upsample5(p5), self.pre4(c4)), dim=1)
        p4 = self.fpn4(x)
        x = torch.cat((self.upsample4(p4), self.pre3(c3)), dim=1)
        n3 = self.fpn3(x)

        # Bottom-up path.
        x = torch.cat((self.downsample3(n3), p4), dim=1)
        n4 = self.pan4(x)
        x = torch.cat((self.downsample4(n4), p5), dim=1)
        n5 = self.pan5(x)
        x = torch.cat((self.downsample5(n5), c6), dim=1)
        n6 = self.pan6(x)

        self.detect3(self.out3(n3), targets, image_size, detections, losses, hits)
        self.detect4(self.out4(n4), targets, image_size, detections, losses, hits)
        self.detect5(self.out5(n5), targets, image_size, detections, losses, hits)
        self.detect6(self.out6(n6), targets, image_size, detections, losses, hits)
        return detections, losses, hits
class YOLOV5Network(nn.Module):
"""The YOLOv5 network architecture. Different variants (n/s/m/l/x) can be achieved by adjusting the ``depth``
and ``width`` parameters.
Args:
num_classes: Number of different classes that this model predicts.
backbone: A backbone network that returns the output from each stage.
width: Number of channels in the narrowest convolutional layer. The wider convolutional layers will use a number
of channels that is a multiple of this value. The values used by the different variants are 16 (yolov5n), 32
(yolov5s), 48 (yolov5m), 64 (yolov5l), and 80 (yolov5x).
depth: Repeat the bottleneck layers this many times. Can be used to make the network deeper. The values used by
the different variants are 1 (yolov5n, yolov5s), 2 (yolov5m), 3 (yolov5l), and 4 (yolov5x).
activation: Which layer activation to use. Can be "relu", "leaky", "mish", "silu" (or "swish"), "logistic",
"linear", or "none".
normalization: Which layer normalization to use. Can be "batchnorm", "groupnorm", or "none".
prior_shapes: A list of prior box dimensions, used for scaling the predicted dimensions and possibly for
matching the targets to the anchors. The list should contain [width, height] pairs in the network input
resolution. There should be `3N` pairs, where `N` is the number of anchors per spatial location. They are
assigned to the layers from the lowest (high-resolution) to the highest (low-resolution) layer, meaning that
you typically want to sort the shapes from the smallest to the largest.
matching_algorithm: Which algorithm to use for matching targets to anchors. "simota" (the SimOTA matching rule
from YOLOX), "size" (match those prior shapes, whose width and height relative to the target is below given
ratio), "iou" (match all prior shapes that give a high enough IoU), or "maxiou" (match the prior shape that
gives the highest IoU, default).
matching_threshold: Threshold for "size" and "iou" matching algorithms.
spatial_range: The "simota" matching algorithm will restrict to anchors that are within an `N x N` grid cell
area centered at the target, where `N` is the value of this parameter.
size_range: The "simota" matching algorithm will restrict to anchors whose dimensions are no more than `N` and
no less than `1/N` times the target dimensions, where `N` is the value of this parameter.
ignore_bg_threshold: If a predictor is not responsible for predicting any target, but the prior shape has IoU
with some target greater than this threshold, the predictor will not be taken into account when calculating
the confidence loss.
overlap_func: Which function to use for calculating the IoU between two sets of boxes. Valid values are "iou",
"giou", "diou", and "ciou".
predict_overlap: Balance between binary confidence targets and predicting the overlap. 0.0 means that the target
confidence is 1 if there's an object, and 1.0 means that the target confidence is the output of
``overlap_func``.
label_smoothing: The epsilon parameter (weight) for class label smoothing. 0.0 means no smoothing (binary
targets), and 1.0 means that the target probabilities are always 0.5.
overlap_loss_multiplier: Overlap loss will be scaled by this value.
confidence_loss_multiplier: Confidence loss will be scaled by this value.
class_loss_multiplier: Classification loss will be scaled by this value.
xy_scale: Eliminate "grid sensitivity" by scaling the box coordinates by this factor. Using a value > 1.0 helps
to produce coordinate values close to one.
"""
def __init__(
self,