jvlmdr/flax-cifar

This is intended to be a Flax port of the pytorch-cifar repo, which I've often used as a reference.

The ResNet models have 'imagenet' and 'cifar' variants, which are intended for 224x224 and 32x32 images respectively.
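To illustrate why separate variants are useful, the sketch below tracks the spatial size of the feature map through a ResNet. It assumes the conventional stems (a stride-2 7x7 convolution followed by a stride-2 max-pool for 'imagenet', a single stride-1 3x3 convolution for 'cifar'), four stages with strides 1, 2, 2, 2, and SAME-style padding; the exact layers in this repo may differ, and the function names are hypothetical.

```python
import math


def same_out(size, stride):
    """Output size of a SAME-padded conv or pool with the given stride."""
    return math.ceil(size / stride)


def final_feature_size(image_size, variant):
    """Spatial size before global pooling, under the assumptions stated above."""
    size = image_size
    if variant == "imagenet":
        size = same_out(size, 2)  # assumed 7x7 conv, stride 2
        size = same_out(size, 2)  # assumed 3x3 max-pool, stride 2
    elif variant == "cifar":
        size = same_out(size, 1)  # assumed 3x3 conv, stride 1, no pooling
    else:
        raise ValueError(f"unknown variant: {variant}")
    for stride in (1, 2, 2, 2):  # assumed strides of the four residual stages
        size = same_out(size, stride)
    return size


print(final_feature_size(224, "imagenet"))
print(final_feature_size(32, "cifar"))
print(final_feature_size(32, "imagenet"))
```

Under these assumptions, applying the 'imagenet' stem to a 32x32 image shrinks the feature map to a single pixel before the final pooling, which is the usual motivation for a lighter 'cifar' stem.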

Credit

The WideResNet and ResNetV2 implementations were based on the objax.zoo package.

The ResNetV1 implementation was based on the flax imagenet example.

The DenseNet and VGG implementations were based on the pytorch-cifar repo.

Results

All results use the default config (see configs/default.py).

Refer to the wandb project.

| Model | Params | Acc. |
| --- | ---: | ---: |
| VGG-11 backbone | 9,228,362 | 90.6% |
| VGG-13 backbone | 9,413,066 | 92.7% |
| VGG-16 backbone | 14,724,042 | 92.1% |
| VGG-19 backbone | 20,035,018 | 92.3% |
| VGG-11 | 28,154,954 | 89.7% |
| VGG-13 | 28,339,658 | 91.9% |
| VGG-16 | 33,650,634 | 92.2% |
| VGG-19 | 38,961,610 | 90.3% |
| ResNetV1-18 | 11,173,962 | 94.1% |
| ResNetV1-50 | 23,520,842 | 94.0% |
| ResNetV2-18 | 11,172,170 | 94.7% |
| ResNetV2-50 | 23,513,162 | 93.8% |
| WideResNet-28-2 | 1,467,610 | 93.3% |
| WideResNet-28-8 | 23,354,842 | 95.1% |
| DenseNet121-12 | 1,000,618 | 93.6% |
| DenseNet121-32 | 6,956,298 | 94.8% |
| DenseNet169-32 | 12,493,322 | 94.2% |
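As a sanity check on the table above, the "VGG-N backbone" parameter counts can be reproduced with plain arithmetic. The sketch below is an assumption about the architecture, not a reading of this repo's code: it uses the standard VGG layer plans from pytorch-cifar, 3x3 convolutions without bias, batch norm with a scale and a bias per channel, and a single 512-to-10 linear classifier. The names `VGG_PLANS` and `vgg_backbone_params` are hypothetical.

```python
# Standard VGG layer plans (as in pytorch-cifar); "M" marks a 2x2 max-pool.
VGG_PLANS = {
    "VGG-11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "VGG-13": [64, 64, "M", 128, 128, "M", 256, 256, "M",
               512, 512, "M", 512, 512, "M"],
    "VGG-16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
               512, 512, 512, "M", 512, 512, 512, "M"],
    "VGG-19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
               512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}


def vgg_backbone_params(plan, in_channels=3, num_classes=10, kernel=3):
    """Count trainable parameters under the assumptions stated above."""
    total = 0
    channels = in_channels
    for item in plan:
        if item == "M":
            continue  # pooling layers have no parameters
        total += kernel * kernel * channels * item  # conv weights (assumed no bias)
        total += 2 * item  # batch-norm scale and bias
        channels = item
    total += channels * num_classes + num_classes  # linear classifier with bias
    return total


for name, plan in VGG_PLANS.items():
    print(f"{name} backbone: {vgg_backbone_params(plan):,}")
```

With these assumptions the four totals agree with the "backbone" rows of the table, which suggests (but does not prove) that the models follow this layout.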

Usage

python main.py --dataset_root=path/to/cifar10 --config=configs/default.py --config.arch=resnet_v1_18

Dependencies

absl-py
flax
jax[cuda]
jaxopt
ml_collections
numpy
torch
torchvision
tqdm
wandb

To install jax with CUDA, refer to the JAX installation instructions.
