Skip to content

Commit b294ed9

Browse files
authored
Add new checkpoint
Hello, I think you miss one checkpoint given that endpoints['reduction_5'] is the head of the network but not the last layer of the backbone. This may be problematic if we use you implementation of EfficientNet as backbone of EfficientDet. In this PR, I let the checkpoint of the head (endpoints['reduction_6']) but changed endpoints['reduction_5'] accordingly. If I'm wrong let me know. Regards, Renaud
1 parent 3d400a5 commit b294ed9

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

efficientnet_pytorch/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,8 @@ def extract_endpoints(self, inputs):
248248
>>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
249249
>>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
250250
>>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
251-
>>> print(endpoints['reduction_5'].shape) # torch.Size([1, 1280, 7, 7])
251+
>>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7])
252+
>>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7])
252253
"""
253254
endpoints = dict()
254255

@@ -264,6 +265,8 @@ def extract_endpoints(self, inputs):
264265
x = block(x, drop_connect_rate=drop_connect_rate)
265266
if prev_x.size(2) > x.size(2):
266267
endpoints['reduction_{}'.format(len(endpoints)+1)] = prev_x
268+
elif idx == len(self._blocks) - 1:
269+
endpoints['reduction_{}'.format(len(endpoints)+1)] = x
267270
prev_x = x
268271

269272
# Head

0 commit comments

Comments
 (0)