Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into norm_norm_norm
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Jan 26, 2022
2 parents abc9ba2 + 07379c6 commit 95cfc9b
Show file tree
Hide file tree
Showing 53 changed files with 7,862 additions and 4,586 deletions.
21 changes: 20 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor

## What's New

### Jan 14, 2022
* Version 0.5.4 w/ release to be pushed to pypi. It's been a while since last pypi update and riskier changes will be merged to main branch soon....
* Add ConvNeXT models /w weights from official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features
* Tried training a few small (~1.8-3M param) / mobile optimized models, a few are good so far, more on the way...
* `mnasnet_small` - 65.6 top-1
* `mobilenetv2_050` - 65.9
* `lcnet_100/075/050` - 72.1 / 68.8 / 63.1
* `semnasnet_075` - 73
* `fbnetv3_b/d/g` - 79.1 / 79.7 / 82.0
* TinyNet models added by [rsomani95](https://github.com/rsomani95)
* LCNet added via MobileNetV3 architecture

### Nov 22, 2021
* A number of updated weights anew new model defs
* `eca_halonext26ts` - 79.5 @ 256
Expand Down Expand Up @@ -255,10 +267,12 @@ All model architecture families include variants with pretrained weights. There
A full version of the list below with source links can be found in the [documentation](https://rwightman.github.io/pytorch-image-models/models/).

* Aggregating Nested Transformers - https://arxiv.org/abs/2105.12723
* BEiT - https://arxiv.org/abs/2106.08254
* Big Transfer ResNetV2 (BiT) - https://arxiv.org/abs/1912.11370
* Bottleneck Transformers - https://arxiv.org/abs/2101.11605
* CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239
* CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399
* ConvNeXt - https://arxiv.org/abs/2201.03545
* ConViT (Soft Convolutional Inductive Biases Vision Transformers)- https://arxiv.org/abs/2103.10697
* CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
* DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877
Expand All @@ -276,19 +290,23 @@ A full version of the list below with source links can be found in the [document
* MNASNet B1, A1 (Squeeze-Excite), and Small - https://arxiv.org/abs/1807.11626
* MobileNet-V2 - https://arxiv.org/abs/1801.04381
* Single-Path NAS - https://arxiv.org/abs/1904.02877
* TinyNet - https://arxiv.org/abs/2010.14819
* GhostNet - https://arxiv.org/abs/1911.11907
* gMLP - https://arxiv.org/abs/2105.08050
* GPU-Efficient Networks - https://arxiv.org/abs/2006.14090
* Halo Nets - https://arxiv.org/abs/2103.12731
* HardCoRe-NAS - https://arxiv.org/abs/2102.11646
* HRNet - https://arxiv.org/abs/1908.07919
* Inception-V3 - https://arxiv.org/abs/1512.00567
* Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261
* Lambda Networks - https://arxiv.org/abs/2102.08602
* LeViT (Vision Transformer in ConvNet's Clothing) - https://arxiv.org/abs/2104.01136
* MLP-Mixer - https://arxiv.org/abs/2105.01601
* MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244
* FBNet-V3 - https://arxiv.org/abs/2006.02049
* HardCoRe-NAS - https://arxiv.org/abs/2102.11646
* LCNet - https://arxiv.org/abs/2109.15099
* NASNet-A - https://arxiv.org/abs/1707.07012
* NesT - https://arxiv.org/abs/2105.12723
* NFNet-F - https://arxiv.org/abs/2102.06171
* NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
* PNasNet - https://arxiv.org/abs/1712.00559
Expand All @@ -314,6 +332,7 @@ A full version of the list below with source links can be found in the [document
* Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112
* TResNet - https://arxiv.org/abs/2003.13630
* Twins (Spatial Attention in Vision Transformers) - https://arxiv.org/pdf/2104.13840.pdf
* Visformer - https://arxiv.org/abs/2104.12533
* Vision Transformer - https://arxiv.org/abs/2010.11929
* VovNet V2 and V1 - https://arxiv.org/abs/1911.06667
* Xception - https://arxiv.org/abs/1610.02357
Expand Down
42 changes: 23 additions & 19 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2
from timm.data import resolve_data_config
from timm.utils import AverageMeter, setup_default_logging
from timm.utils import setup_default_logging, set_jit_fuser


has_apex = False
Expand Down Expand Up @@ -95,7 +95,8 @@
help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')

parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")


# train optimizer parameters
Expand Down Expand Up @@ -186,14 +187,16 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
class BenchmarkRunner:
def __init__(
self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
self.model_name = model_name
self.detail = detail
self.device = device
self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
self.channels_last = kwargs.pop('channels_last', False)
self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress

if fuser:
set_jit_fuser(fuser)
self.model = create_model(
model_name,
num_classes=kwargs.pop('num_classes', None),
Expand Down Expand Up @@ -311,10 +314,7 @@ def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
self.model.train()

if kwargs.pop('smoothing', 0) > 0:
self.loss = nn.CrossEntropyLoss().to(self.device)
else:
self.loss = nn.CrossEntropyLoss().to(self.device)
self.loss = nn.CrossEntropyLoss().to(self.device)
self.target_shape = tuple()

self.optimizer = create_optimizer_v2(
Expand Down Expand Up @@ -477,20 +477,21 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16):
def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
batch_size = initial_batch_size
results = dict()
error_str = 'Unknown'
while batch_size >= 1:
torch.cuda.empty_cache()
try:
bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
results = bench.run()
return results
except RuntimeError as e:
e_str = str(e)
print(e_str)
if 'channels_last' in e_str:
print(f'Error: {model_name} not supported in channels_last, skipping.')
error_str = str(e)
if 'channels_last' in error_str:
_logger.error(f'{model_name} not supported in channels_last, skipping.')
break
print(f'Error: "{e_str}" while running benchmark. Reducing batch size to {batch_size} for retry.')
_logger.warning(f'"{error_str}" while running benchmark. Reducing batch size to {batch_size} for retry.')
batch_size = decay_batch_exp(batch_size)
results['error'] = error_str
return results


Expand Down Expand Up @@ -532,13 +533,14 @@ def benchmark(args):
model_results = OrderedDict(model=model)
for prefix, bench_fn in zip(prefixes, bench_fns):
run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs)
if prefix:
if prefix and 'error' not in run_results:
run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
model_results.update(run_results)
param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
model_results.setdefault('param_count', param_count)
model_results.pop('train_param_count', 0)
return model_results if model_results['param_count'] else dict()
if 'error' not in model_results:
param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
model_results.setdefault('param_count', param_count)
model_results.pop('train_param_count', 0)
return model_results


def main():
Expand Down Expand Up @@ -582,13 +584,15 @@ def main():
sort_key = 'train_samples_per_sec'
elif 'profile' in args.bench:
sort_key = 'infer_gmacs'
results = filter(lambda x: sort_key in x, results)
results = sorted(results, key=lambda x: x[sort_key], reverse=True)
if len(results):
write_results(results_file, results)
else:
results = benchmark(args)
json_str = json.dumps(results, indent=4)
print(json_str)

# output results in JSON to stdout w/ delimiter for runner script
print(f'--result\n{json.dumps(results, indent=4)}')


def write_results(results_file, results):
Expand Down
10 changes: 5 additions & 5 deletions docs/training_hparam_examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
## EfficientNet-B2 with RandAugment - 80.4 top-1, 95.1 top-5
These params are for dual Titan RTX cards with NVIDIA Apex installed:

`./distributed_train.sh 2 /imagenet/ --model efficientnet_b2 -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-connect 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .016`
`./distributed_train.sh 2 /imagenet/ --model efficientnet_b2 -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .016`

## MixNet-XL with RandAugment - 80.5 top-1, 94.9 top-5
This params are for dual Titan RTX cards with NVIDIA Apex installed:

`./distributed_train.sh 2 /imagenet/ --model mixnet_xl -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .969 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-connect 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.3 --amp --lr .016 --dist-bn reduce`
`./distributed_train.sh 2 /imagenet/ --model mixnet_xl -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .969 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.3 --amp --lr .016 --dist-bn reduce`

## SE-ResNeXt-26-D and SE-ResNeXt-26-T
These hparams (or similar) work well for a wide range of ResNet architecture, generally a good idea to increase the epoch # as the model size increases... ie approx 180-200 for ResNe(X)t50, and 220+ for larger. Increase batch size and LR proportionally for better GPUs or with AMP enabled. These params were for 2 1080Ti cards:
Expand All @@ -21,7 +21,7 @@ The training of this model started with the same command line as EfficientNet-B2
## EfficientNet-B0 with RandAugment - 77.7 top-1, 95.3 top-5
[Michael Klachko](https://github.com/michaelklachko) achieved these results with the command line for B2 adapted for larger batch size, with the recommended B0 dropout rate of 0.2.

`./distributed_train.sh 2 /imagenet/ --model efficientnet_b0 -b 384 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .048`
`./distributed_train.sh 2 /imagenet/ --model efficientnet_b0 -b 384 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .048`

## ResNet50 with JSD loss and RandAugment (clean + 2x RA augs) - 79.04 top-1, 94.39 top-5

Expand All @@ -32,11 +32,11 @@ Trained on two older 1080Ti cards, this took a while. Only slightly, non statist
## EfficientNet-ES (EdgeTPU-Small) with RandAugment - 78.066 top-1, 93.926 top-5
Trained by [Andrew Lavin](https://github.com/andravin) with 8 V100 cards. Model EMA was not used, final checkpoint is the average of 8 best checkpoints during training.

`./distributed_train.sh 8 /imagenet --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064`
`./distributed_train.sh 8 /imagenet --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064`

## MobileNetV3-Large-100 - 75.766 top-1, 92,542 top-5

`./distributed_train.sh 2 /imagenet/ --model mobilenetv3_large_100 -b 512 --sched step --epochs 600 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 --lr-noise 0.42 0.9`
`./distributed_train.sh 2 /imagenet/ --model mobilenetv3_large_100 -b 512 --sched step --epochs 600 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 --lr-noise 0.42 0.9`


## ResNeXt-50 32x4d w/ RandAugment - 79.762 top-1, 94.60 top-5
Expand Down
Loading

0 comments on commit 95cfc9b

Please sign in to comment.