Skip to content

Commit

Permalink
fix #820
Browse files Browse the repository at this point in the history
  • Loading branch information
YahooKID authored and Chilicyy committed May 11, 2023
1 parent 5a67f6a commit 6e58d1b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 25 deletions.
21 changes: 15 additions & 6 deletions yolov6/models/efficientrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,11 +387,20 @@ def __init__(
block=RepVGGBlock,
csp_e=float(1)/2,
fuse_P2=False,
cspsppf=False
cspsppf=False,
stage_block_type="BepC3"
):
super().__init__()
assert channels_list is not None
assert num_repeats is not None

if stage_block_type == "BepC3":
stage_block = BepC3
elif stage_block_type == "MBLABlock":
stage_block = MBLABlock
else:
raise NotImplementedError

self.fuse_P2 = fuse_P2

self.stem = block(
Expand All @@ -408,7 +417,7 @@ def __init__(
kernel_size=3,
stride=2
),
BepC3(
stage_block(
in_channels=channels_list[1],
out_channels=channels_list[1],
n=num_repeats[1],
Expand All @@ -424,7 +433,7 @@ def __init__(
kernel_size=3,
stride=2
),
BepC3(
stage_block(
in_channels=channels_list[2],
out_channels=channels_list[2],
n=num_repeats[2],
Expand All @@ -440,7 +449,7 @@ def __init__(
kernel_size=3,
stride=2
),
BepC3(
stage_block(
in_channels=channels_list[3],
out_channels=channels_list[3],
n=num_repeats[3],
Expand All @@ -460,7 +469,7 @@ def __init__(
kernel_size=3,
stride=2,
),
BepC3(
stage_block(
in_channels=channels_list[4],
out_channels=channels_list[4],
n=num_repeats[4],
Expand All @@ -475,7 +484,7 @@ def __init__(
kernel_size=3,
stride=2,
),
BepC3(
stage_block(
in_channels=channels_list[5],
out_channels=channels_list[5],
n=num_repeats[5],
Expand Down
62 changes: 43 additions & 19 deletions yolov6/models/reppan.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,38 +551,46 @@ def __init__(
channels_list=None,
num_repeats=None,
block=BottleRep,
csp_e=float(1)/2
csp_e=float(1)/2,
stage_block_type="BepC3"
):
super().__init__()

if stage_block_type == "BepC3":
stage_block = BepC3
elif stage_block_type == "MBLABlock":
stage_block = MBLABlock
else:
raise NotImplementedError

assert channels_list is not None
assert num_repeats is not None

self.Rep_p4 = BepC3(
self.Rep_p4 = stage_block(
in_channels=channels_list[3] + channels_list[5], # 512 + 256
out_channels=channels_list[5], # 256
n=num_repeats[5],
e=csp_e,
block=block
)

self.Rep_p3 = BepC3(
self.Rep_p3 = stage_block(
in_channels=channels_list[2] + channels_list[6], # 256 + 128
out_channels=channels_list[6], # 128
n=num_repeats[6],
e=csp_e,
block=block
)

self.Rep_n3 = BepC3(
self.Rep_n3 = stage_block(
in_channels=channels_list[6] + channels_list[7], # 128 + 128
out_channels=channels_list[8], # 256
n=num_repeats[7],
e=csp_e,
block=block
)

self.Rep_n4 = BepC3(
self.Rep_n4 = stage_block(
in_channels=channels_list[5] + channels_list[9], # 256 + 256
out_channels=channels_list[10], # 512
n=num_repeats[8],
Expand Down Expand Up @@ -787,13 +795,21 @@ def __init__(
channels_list=None,
num_repeats=None,
block=BottleRep,
csp_e=float(1)/2
csp_e=float(1)/2,
stage_block_type="BepC3"
):
super().__init__()

assert channels_list is not None
assert num_repeats is not None

if stage_block_type == "BepC3":
stage_block = BepC3
elif stage_block_type == "MBLABlock":
stage_block = MBLABlock
else:
raise NotImplementedError

self.reduce_layer0 = ConvBNReLU(
in_channels=channels_list[5], # 1024
out_channels=channels_list[6], # 512
Expand All @@ -806,7 +822,7 @@ def __init__(
out_channels=channels_list[6], # 512
)

self.Rep_p5 = BepC3(
self.Rep_p5 = stage_block(
in_channels=channels_list[4] + channels_list[6], # 768 + 512
out_channels=channels_list[6], # 512
n=num_repeats[6],
Expand All @@ -826,7 +842,7 @@ def __init__(
out_channels=channels_list[7] # 256
)

self.Rep_p4 = BepC3(
self.Rep_p4 = stage_block(
in_channels=channels_list[3] + channels_list[7], # 512 + 256
out_channels=channels_list[7], # 256
n=num_repeats[7],
Expand All @@ -846,7 +862,7 @@ def __init__(
out_channels=channels_list[8] # 128
)

self.Rep_p3 = BepC3(
self.Rep_p3 = stage_block(
in_channels=channels_list[2] + channels_list[8], # 256 + 128
out_channels=channels_list[8], # 128
n=num_repeats[8],
Expand All @@ -861,7 +877,7 @@ def __init__(
stride=2
)

self.Rep_n4 = BepC3(
self.Rep_n4 = stage_block(
in_channels=channels_list[8] + channels_list[8], # 128 + 128
out_channels=channels_list[9], # 256
n=num_repeats[9],
Expand All @@ -876,7 +892,7 @@ def __init__(
stride=2
)

self.Rep_n5 = BepC3(
self.Rep_n5 = stage_block(
in_channels=channels_list[7] + channels_list[9], # 256 + 256
out_channels=channels_list[10], # 512
n=num_repeats[10],
Expand All @@ -891,7 +907,7 @@ def __init__(
stride=2
)

self.Rep_n6 = BepC3(
self.Rep_n6 = stage_block(
in_channels=channels_list[6] + channels_list[10], # 512 + 512
out_channels=channels_list[11], # 1024
n=num_repeats[11],
Expand Down Expand Up @@ -946,13 +962,21 @@ def __init__(
channels_list=None,
num_repeats=None,
block=BottleRep,
csp_e=float(1)/2
csp_e=float(1)/2,
stage_block_type="BepC3"
):
super().__init__()

assert channels_list is not None
assert num_repeats is not None

if stage_block_type == "BepC3":
stage_block = BepC3
elif stage_block_type == "MBLABlock":
stage_block = MBLABlock
else:
raise NotImplementedError

self.reduce_layer0 = ConvBNReLU(
in_channels=channels_list[5], # 1024
out_channels=channels_list[6], # 512
Expand All @@ -965,7 +989,7 @@ def __init__(
out_channels=channels_list[6], # 512
)

self.Rep_p5 = BepC3(
self.Rep_p5 = stage_block(
in_channels=channels_list[6], # 512
out_channels=channels_list[6], # 512
n=num_repeats[6],
Expand All @@ -985,7 +1009,7 @@ def __init__(
out_channels=channels_list[7], # 256
)

self.Rep_p4 = BepC3(
self.Rep_p4 = stage_block(
in_channels=channels_list[7], # 256
out_channels=channels_list[7], # 256
n=num_repeats[7],
Expand All @@ -1005,7 +1029,7 @@ def __init__(
out_channels=channels_list[8], # 128
)

self.Rep_p3 = BepC3(
self.Rep_p3 = stage_block(
in_channels=channels_list[8], # 128
out_channels=channels_list[8], # 128
n=num_repeats[8],
Expand All @@ -1020,7 +1044,7 @@ def __init__(
stride=2
)

self.Rep_n4 = BepC3(
self.Rep_n4 = stage_block(
in_channels=channels_list[8] + channels_list[8], # 128 + 128
out_channels=channels_list[9], # 256
n=num_repeats[9],
Expand All @@ -1035,7 +1059,7 @@ def __init__(
stride=2
)

self.Rep_n5 = BepC3(
self.Rep_n5 = stage_block(
in_channels=channels_list[7] + channels_list[9], # 256 + 256
out_channels=channels_list[10], # 512
n=num_repeats[10],
Expand All @@ -1050,7 +1074,7 @@ def __init__(
stride=2
)

self.Rep_n6 = BepC3(
self.Rep_n6 = stage_block(
in_channels=channels_list[6] + channels_list[10], # 512 + 512
out_channels=channels_list[11], # 1024
n=num_repeats[11],
Expand Down

0 comments on commit 6e58d1b

Please sign in to comment.