Unofficial PyTorch code for "FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence," NeurIPS'20.
This implementation can reproduce the CIFAR10 and CIFAR100 results reported in the paper.
In addition, it provides models trained in both semi-supervised and fully supervised settings (download links are given below).
- python 3.6
- pytorch 1.6.0
- torchvision 0.7.0
- tensorboard 2.3.0
- pillow
In addition to the semi-supervised results reported in the paper, we also report extra results for fully supervised learning (50000 labels, sup only) and for fully supervised learning with consistency regularization (50000 labels, sup + consistency).
Consistency regularization improves classification accuracy even when all labels are provided.
Evaluation uses an EMA (exponential moving average) of the model weights along the SGD training trajectory.
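The EMA evaluation mentioned above can be sketched as follows. This is a minimal, framework-free illustration rather than the code used in this repo: the function name `ema_update`, the toy scalar "parameters", and the decay values are assumptions chosen for the example.

```python
def ema_update(ema_params, model_params, decay=0.999):
    """Blend the current model parameters into the running EMA copy, in place.

    `decay` close to 1.0 makes the EMA change slowly; 0.999 is an assumed
    default for illustration, not a value read from this repo's code.
    """
    for name, value in model_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params


# Toy "model" with scalar parameters instead of tensors (hypothetical).
model = {"w": 0.0, "b": 0.0}
ema = dict(model)  # the EMA copy starts from the initial weights

for step in range(3):
    # Stand-in for one SGD step on the training model.
    model["w"] += 1.0
    model["b"] -= 0.5
    # After every optimizer step, move the EMA copy toward the new weights.
    ema_update(ema, model, decay=0.9)

# The EMA copy, not `model`, would be the one used for evaluation.
print(model, ema)
```

The EMA copy lags behind the raw weights and averages out step-to-step noise, which is why it is the copy used at evaluation time.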
#Labels | 40 | 250 | 4000 | sup + consistency | sup only |
---|---|---|---|---|---|
Paper (RA) | 86.19 ± 3.37 | 94.93 ± 0.65 | 95.74 ± 0.05 | - | - |
kekmodel | - | - | 94.72 | - | - |
valencebond | 89.63(85.65) | 93.08 | 94.72 | - | - |
Ours | 87.11 | 94.61 | 95.62 | 96.86 | 94.98 |
Trained Models | checkpoint | checkpoint | checkpoint | checkpoint | checkpoint |
#Labels | 400 | 2500 | 10000 | sup + consistency | sup only |
---|---|---|---|---|---|
Paper (RA) | 51.15 ± 1.75 | 71.71 ± 0.11 | 77.40 ± 0.12 | - | - |
kekmodel | - | - | - | - | - |
valencebond | 53.74 | 67.3169 | 73.26 | - | - |
Ours | 48.96 | 71.50 | 78.27 | 83.86 | 80.57 |
Trained Models | checkpoint | checkpoint | checkpoint | checkpoint | checkpoint |
In the case of CIFAR100@400, our result does not reach the paper's result and falls outside its confidence interval.
Note, however, that accuracy with a small number of labels depends heavily on the label selection and other hyperparameters.
For example, we find that changing the momentum of batch normalization can give better results, close to the reported accuracies.
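As a quick check of the gap for CIFAR100 with 400 labels, the numbers from the table can be compared directly. This is only arithmetic on the reported values; the helper name `within_band` is hypothetical, and the "±" value is treated as a symmetric band around the paper's mean.

```python
def within_band(value, mean, spread):
    """Return True if `value` lies inside [mean - spread, mean + spread]."""
    return mean - spread <= value <= mean + spread


# Values taken from the CIFAR100 table above (400 labels).
paper_mean, paper_spread = 51.15, 1.75
ours = 48.96

inside = within_band(ours, paper_mean, paper_spread)
# Distance between our accuracy and the lower edge of the paper's band.
gap_to_lower_edge = round((paper_mean - paper_spread) - ours, 2)
print(inside, gap_to_lower_edge)
```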
Here we attach some Google Drive links, which include training logs and the trained models.
Because of Google Drive's security checks,
downloading a checkpoint from the result tables with curl/wget may fail.
In that case, use gdown to download it.
All checkpoints are included in this directory.
After unzipping the checkpoints into your own path, you can run
python eval.py --load_path saved_models/cifar10_400/model_best.pth --dataset cifar10 --num_classes 10
For detailed explanations of the arguments, see here.
- In training, the model is saved at `os.path.join(args.save_dir, args.save_name)`, after a new directory is created. If the path already exists, the code raises an error to prevent overwriting trained models by mistake. If you want to overwrite the files, pass `--overwrite`.
- By default, FixMatch uses hard (one-hot) pseudo labels. If you want to use soft pseudo labels with sharpening (T), pass `--hard_label False`. You can also adjust the sharpening temperature with `--T (YOUR_OWN_VALUE)`.
- This code treats training as a single epoch consisting of 2**20 iterations.
- To restart training, use `--resume --load_path [YOUR_CHECKPOINT_PATH]`. The checkpoint is then loaded into the model, and training continues from the iteration where it stopped. See here and the related method.
- The number of workers for `DataLoader` was set for distributed training on a single node with 4 V100 GPUs.
- To change the confidence threshold used to generate masks in consistency regularization, change `--p_cutoff`.
- With 4 GPUs, the running statistics of BN are not gathered across processes in distributed training, which keeps updates fast. However, using a larger number of GPUs with the same batch size might affect overall accuracy. In that case, you can 1) replace BN with syncBN (see here) or 2) use `torch.distributed.all_reduce` on the BN buffers before this line.
- We checked that syncBN slightly improves accuracy, but it greatly increases training time. Thus, this code doesn't include it.
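The hard/soft pseudo-label options and the confidence threshold described in the notes above can be illustrated with a small, framework-free sketch. It is not the repo's implementation: the function names are hypothetical, and the default values `p_cutoff=0.95` and `T=0.5` are assumptions made for the example.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def pseudo_label(weak_logits, p_cutoff=0.95, hard_label=True, T=0.5):
    """Build a pseudo label and a 0/1 mask from weakly augmented logits."""
    probs = softmax(weak_logits)
    confidence = max(probs)
    # The mask keeps only predictions at or above the confidence threshold.
    mask = 1.0 if confidence >= p_cutoff else 0.0
    if hard_label:
        top = probs.index(confidence)
        target = [1.0 if i == top else 0.0 for i in range(len(probs))]
    else:
        # Soft label: sharpen the distribution with temperature T < 1.
        target = softmax(weak_logits, temperature=T)
    return target, mask


def masked_consistency_loss(strong_logits, target, mask):
    """Cross-entropy between the target and the strong-view prediction."""
    log_probs = [math.log(p) for p in softmax(strong_logits)]
    return -sum(t * lp for t, lp in zip(target, log_probs)) * mask


confident_target, confident_mask = pseudo_label([5.0, 0.0, 0.0])
unsure_target, unsure_mask = pseudo_label([1.0, 0.5, 0.0])
soft_target, _ = pseudo_label([5.0, 0.0, 0.0], hard_label=False)
print(confident_mask, unsure_mask, soft_target)
```

Samples whose mask is 0 contribute nothing to the consistency loss, so raising the threshold trades the amount of unlabeled data used for pseudo-label reliability.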
python train.py --rank 0 --gpu [0/1/...] @@@other args@@@
python train.py --world-size 1 --rank 0 @@@other args@@@
When you use multiple GPUs, we strongly recommend distributed training (even on a single node) for high performance.
With 4 V100 GPUs, CIFAR10 training takes about 16 hours (0.7 days), and CIFAR100 training takes about 62 hours (2.6 days).
- single node
python train.py --world-size 1 --rank 0 --multiprocessing-distributed @@@other args@@@
- multiple nodes (assuming two nodes)
# at node 0
python train.py --world-size 2 --rank 0 --dist_url [rank 0's url] --multiprocessing-distributed @@@@other args@@@@
# at node 1
python train.py --world-size 2 --rank 1 --dist_url [rank 0's url] --multiprocessing-distributed @@@@other args@@@@
python train.py --world-size 1 --rank 0 --multiprocessing-distributed --num_labels 4000 --save_name cifar10_4000 --dataset cifar10 --num_classes 10
python train.py --world-size 1 --rank 0 --multiprocessing-distributed --num_labels 10000 --save_name cifar100_10000 --dataset cifar100 --num_classes 100 --widen_factor 8 --weight_decay 0.001
To reproduce the results on CIFAR100, `--widen_factor` has to be increased to 8 (see this issue in the official repo), and `--weight_decay` has to be set to 0.001.
In this repo, we use WideResNet with LeakyReLU activations, implemented in `models/net/wrn.py`.
When you use WideResNet, you can change widen_factor, leaky_slope, and dropRate through the corresponding arguments.
For example, if you want to use ReLU, just pass `--leaky_slope 0.0` in the arguments.
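The reason a slope of 0.0 yields ReLU can be seen from the definition of the leaky activation. The snippet below is a scalar sketch with hypothetical function names, not code from `models/net/wrn.py`.

```python
def leaky_relu(x, negative_slope):
    """Scalar LeakyReLU: identity for x >= 0, scaled by the slope otherwise."""
    return x if x >= 0 else negative_slope * x


def relu(x):
    """Scalar ReLU."""
    return max(0.0, x)


samples = [-2.0, -0.5, 0.0, 0.5, 2.0]
# With a negative slope of 0.0, every negative input is mapped to zero,
# so the two activations agree on all inputs.
matches = all(leaky_relu(x, 0.0) == relu(x) for x in samples)
print(matches)
```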
We also support various backbone networks from `torchvision.models`.
If you want to use another backbone network from torchvision, change the arguments to
--net [MODEL's NAME in torchvision] --net_from_name True
When `--net_from_name True` is set, all model arguments other than `--net` are ignored.
If you want to use mixed-precision training for a speed-up, add `--amp` to the arguments.
We observed that the training time per iteration is reduced by about 20-30%.
We track various metrics, including training accuracy, prefetch and run times, the mask ratio of unlabeled data, and learning rates. See the details here. You can view the metrics in TensorBoard:
tensorboard --logdir=[SAVE PATH] --port=[YOUR PORT]