Skip to content

Commit

Permalink
Merge pull request #1919 from ChengpengChen/main
Browse files Browse the repository at this point in the history
Add RepGhost models and weights
  • Loading branch information
rwightman authored Aug 19, 2023
2 parents b801156 + 69e0ca2 commit 7c2728c
Show file tree
Hide file tree
Showing 3 changed files with 481 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_model_default_cfgs(model_name, batch_size):
outputs = model.forward_features(input_tensor)
assert outputs.shape[spatial_axis[0]] == pool_size[0], 'unpooled feature shape != config'
assert outputs.shape[spatial_axis[1]] == pool_size[1], 'unpooled feature shape != config'
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
assert outputs.shape[feat_axis] == model.num_features

# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
Expand All @@ -188,16 +188,16 @@ def test_model_default_cfgs(model_name, batch_size):
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
# mobilenetv3/ghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
# mobilenetv3/ghostnet/repghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]

if 'pruned' not in model_name: # FIXME better pruned model handling
# test classifier + global pool deletion via __init__
model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval()
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]

# check classifier name matches default_cfg
Expand Down
1 change: 1 addition & 0 deletions timm/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .pnasnet import *
from .pvt_v2 import *
from .regnet import *
from .repghost import *
from .repvit import *
from .res2net import *
from .resnest import *
Expand Down
Loading

0 comments on commit 7c2728c

Please sign in to comment.