Skip to content

Commit

Permalink
Add extra presets for GL ResNet for CF
Browse files Browse the repository at this point in the history
  • Loading branch information
osmr committed Jul 8, 2019
1 parent 19e1a3c commit f7847d0
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 2 deletions.
3 changes: 3 additions & 0 deletions gluon/gluoncv2/model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,9 @@
'resnet164bn_cifar10': resnet164bn_cifar10,
'resnet164bn_cifar100': resnet164bn_cifar100,
'resnet164bn_svhn': resnet164bn_svhn,
'resnet272bn_cifar10': resnet272bn_cifar10,
'resnet272bn_cifar100': resnet272bn_cifar100,
'resnet272bn_svhn': resnet272bn_svhn,
'resnet1001_cifar10': resnet1001_cifar10,
'resnet1001_cifar100': resnet1001_cifar100,
'resnet1001_svhn': resnet1001_svhn,
Expand Down
68 changes: 66 additions & 2 deletions gluon/gluoncv2/models/resnet_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

__all__ = ['CIFARResNet', 'resnet20_cifar10', 'resnet20_cifar100', 'resnet20_svhn', 'resnet56_cifar10',
'resnet56_cifar100', 'resnet56_svhn', 'resnet110_cifar10', 'resnet110_cifar100', 'resnet110_svhn',
'resnet164bn_cifar10', 'resnet164bn_cifar100', 'resnet164bn_svhn', 'resnet1001_cifar10',
'resnet1001_cifar100', 'resnet1001_svhn', 'resnet1202_cifar10', 'resnet1202_cifar100', 'resnet1202_svhn']
'resnet164bn_cifar10', 'resnet164bn_cifar100', 'resnet164bn_svhn', 'resnet272bn_cifar10',
'resnet272bn_cifar100', 'resnet272bn_svhn', 'resnet1001_cifar10', 'resnet1001_cifar100', 'resnet1001_svhn',
'resnet1202_cifar10', 'resnet1202_cifar100', 'resnet1202_svhn']

import os
from mxnet import cpu
Expand Down Expand Up @@ -372,6 +373,63 @@ def resnet164bn_svhn(classes=10, **kwargs):
return get_resnet_cifar(classes=classes, blocks=164, bottleneck=True, model_name="resnet164bn_svhn", **kwargs)


def resnet272bn_cifar10(classes=10, **kwargs):
"""
ResNet-272(BN) model for CIFAR-10 from 'Deep Residual Learning for Image Recognition,'
https://arxiv.org/abs/1512.03385.
Parameters:
----------
classes : int, default 10
Number of classification classes.
pretrained : bool, default False
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
"""
return get_resnet_cifar(classes=classes, blocks=272, bottleneck=True, model_name="resnet272bn_cifar10", **kwargs)


def resnet272bn_cifar100(classes=100, **kwargs):
"""
ResNet-272(BN) model for CIFAR-100 from 'Deep Residual Learning for Image Recognition,'
https://arxiv.org/abs/1512.03385.
Parameters:
----------
classes : int, default 100
Number of classification classes.
pretrained : bool, default False
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
"""
return get_resnet_cifar(classes=classes, blocks=272, bottleneck=True, model_name="resnet272bn_cifar100", **kwargs)


def resnet272bn_svhn(classes=10, **kwargs):
"""
ResNet-272(BN) model for SVHN from 'Deep Residual Learning for Image Recognition,'
https://arxiv.org/abs/1512.03385.
Parameters:
----------
classes : int, default 10
Number of classification classes.
pretrained : bool, default False
Whether to load the pretrained weights for model.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
"""
return get_resnet_cifar(classes=classes, blocks=272, bottleneck=True, model_name="resnet272bn_svhn", **kwargs)


def resnet1001_cifar10(classes=10, **kwargs):
"""
ResNet-1001 model for CIFAR-10 from 'Deep Residual Learning for Image Recognition,'
Expand Down Expand Up @@ -505,6 +563,9 @@ def _test():
(resnet164bn_cifar10, 10),
(resnet164bn_cifar100, 100),
(resnet164bn_svhn, 10),
(resnet272bn_cifar10, 10),
(resnet272bn_cifar100, 100),
(resnet272bn_svhn, 10),
(resnet1001_cifar10, 10),
(resnet1001_cifar100, 100),
(resnet1001_svhn, 10),
Expand Down Expand Up @@ -541,6 +602,9 @@ def _test():
assert (model != resnet164bn_cifar10 or weight_count == 1704154)
assert (model != resnet164bn_cifar100 or weight_count == 1727284)
assert (model != resnet164bn_svhn or weight_count == 1704154)
assert (model != resnet272bn_cifar10 or weight_count == 2816986)
assert (model != resnet272bn_cifar100 or weight_count == 2840116)
assert (model != resnet272bn_svhn or weight_count == 2816986)
assert (model != resnet1001_cifar10 or weight_count == 10328602)
assert (model != resnet1001_cifar100 or weight_count == 10351732)
assert (model != resnet1001_svhn or weight_count == 10328602)
Expand Down

0 comments on commit f7847d0

Please sign in to comment.