Fix link in the CCT example #541

Merged · 4 commits · Jul 5, 2021
7 changes: 4 additions & 3 deletions examples/vision/cct.py
@@ -365,14 +365,15 @@ def run_experiment(model):
"""
The CCT model we just trained has just **0.4 million** parameters, and it gets us to
~78% top-1 accuracy within 30 epochs. The plot above shows no signs of overfitting as
-well. This means we can train this network for longers (perhaps with a bit more
+well. This means we can train this network for longer (perhaps with a bit more
regularization) and may obtain even better performance. This performance can further be
improved by additional recipes like cosine decay learning rate schedule, other data augmentation
techniques like [AutoAugment](https://arxiv.org/abs/1805.09501),
[MixUp](https://arxiv.org/abs/1710.09412) or
-[Cutmix](https://arxiv.org/abs/1905.04899. The authors also present a number of
+[Cutmix](https://arxiv.org/abs/1905.04899). With these modifications, the authors present
+95.1% top-1 accuracy on the CIFAR-10 dataset. The authors also present a number of
experiments to study how the number of convolution blocks, Transformers layers, etc.
-affect the final performance.
+affect the final performance of CCTs.

For a comparison, a ViT model takes about **4.7 million** parameters and **100
epochs** of training to reach a top-1 accuracy of 78.22% on the CIFAR-10 dataset. You can
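The corrected paragraph above names a cosine decay learning rate schedule as one recipe for squeezing out more performance. As a rough sketch of how that could be wired into this example (an illustration, not part of this PR; `num_epochs`, `batch_size`, `train_size`, and the AdamW hyperparameters are assumptions standing in for the example's own values):

```python
import tensorflow as tf
import tensorflow_addons as tfa  # AdamW lives in TensorFlow Addons

# Hypothetical values mirroring a typical CIFAR-10 setup.
num_epochs = 30
batch_size = 128
train_size = 45000  # 50k training images minus a 10% validation split
total_steps = (train_size // batch_size) * num_epochs

# Anneal the learning rate from 1e-3 down to 0 over the whole run.
schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-3,
    decay_steps=total_steps,
)
optimizer = tfa.optimizers.AdamW(learning_rate=schedule, weight_decay=1e-4)
```

Passing the schedule object in place of a fixed learning rate is all Keras needs; the optimizer queries it at every training step.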
Binary file modified examples/vision/img/cct/cct_22_0.png
7 changes: 4 additions & 3 deletions examples/vision/ipynb/cct.ipynb
@@ -531,14 +531,15 @@
"source": [
"The CCT model we just trained has just **0.4 million** parameters, and it gets us to\n",
"~78% top-1 accuracy within 30 epochs. The plot above shows no signs of overfitting as\n",
"well. This means we can train this network for longers (perhaps with a bit more\n",
"well. This means we can train this network for longer (perhaps with a bit more\n",
"regularization) and may obtain even better performance. This performance can further be\n",
"improved by additional recipes like cosine decay learning rate schedule, other data augmentation\n",
"techniques like [AutoAugment](https://arxiv.org/abs/1805.09501),\n",
"[MixUp](https://arxiv.org/abs/1710.09412) or\n",
"[Cutmix](https://arxiv.org/abs/1905.04899. The authors also present a number of\n",
"[Cutmix](https://arxiv.org/abs/1905.04899). With these modifications, the authors present\n",
"95.1% top-1 accuracy on the CIFAR-10 dataset. The authors also present a number of\n",
"experiments to study how the number of convolution blocks, Transformers layers, etc.\n",
"affect the final performance.\n",
"affect the final performance of CCTs.\n",
"\n",
"For a comparison, a ViT model takes about **4.7 million** parameters and **100\n",
"epochs** of training to reach a top-1 accuracy of 78.22% on the CIFAR-10 dataset. You can\n",
76 changes: 39 additions & 37 deletions examples/vision/md/cct.md
@@ -1,3 +1,4 @@
+
# Compact Convolutional Transformers

**Author:** [Sayak Paul](https://twitter.com/RisingSayak)<br>
@@ -43,7 +44,7 @@ be installed using the following command:

<div class="k-default-codeblock">
```
- |████████████████████████████████| 686kB 5.3MB/s
+ |████████████████████████████████| 686kB 5.4MB/s
[?25h

```
@@ -390,68 +391,68 @@ history = run_experiment(cct_model)
<div class="k-default-codeblock">
```
Epoch 1/30
-352/352 [==============================] - 10s 17ms/step - loss: 1.9019 - accuracy: 0.3357 - top-5-accuracy: 0.8326 - val_loss: 1.6537 - val_accuracy: 0.4596 - val_top-5-accuracy: 0.9206
+352/352 [==============================] - 10s 18ms/step - loss: 1.9181 - accuracy: 0.3277 - top-5-accuracy: 0.8296 - val_loss: 1.7123 - val_accuracy: 0.4250 - val_top-5-accuracy: 0.9028
Epoch 2/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.5560 - accuracy: 0.5058 - top-5-accuracy: 0.9341 - val_loss: 1.4756 - val_accuracy: 0.5466 - val_top-5-accuracy: 0.9462
+352/352 [==============================] - 6s 16ms/step - loss: 1.5725 - accuracy: 0.5010 - top-5-accuracy: 0.9295 - val_loss: 1.5026 - val_accuracy: 0.5530 - val_top-5-accuracy: 0.9364
Epoch 3/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.4379 - accuracy: 0.5646 - top-5-accuracy: 0.9527 - val_loss: 1.3775 - val_accuracy: 0.6016 - val_top-5-accuracy: 0.9622
+352/352 [==============================] - 6s 16ms/step - loss: 1.4492 - accuracy: 0.5633 - top-5-accuracy: 0.9476 - val_loss: 1.3744 - val_accuracy: 0.6038 - val_top-5-accuracy: 0.9558
Epoch 4/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.3568 - accuracy: 0.6067 - top-5-accuracy: 0.9611 - val_loss: 1.3125 - val_accuracy: 0.6288 - val_top-5-accuracy: 0.9658
+352/352 [==============================] - 6s 16ms/step - loss: 1.3658 - accuracy: 0.6055 - top-5-accuracy: 0.9576 - val_loss: 1.3258 - val_accuracy: 0.6148 - val_top-5-accuracy: 0.9648
Epoch 5/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.2905 - accuracy: 0.6386 - top-5-accuracy: 0.9668 - val_loss: 1.2665 - val_accuracy: 0.6506 - val_top-5-accuracy: 0.9712
+352/352 [==============================] - 6s 16ms/step - loss: 1.3142 - accuracy: 0.6302 - top-5-accuracy: 0.9640 - val_loss: 1.2723 - val_accuracy: 0.6468 - val_top-5-accuracy: 0.9710
Epoch 6/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.2438 - accuracy: 0.6612 - top-5-accuracy: 0.9710 - val_loss: 1.2220 - val_accuracy: 0.6740 - val_top-5-accuracy: 0.9728
+352/352 [==============================] - 6s 16ms/step - loss: 1.2729 - accuracy: 0.6489 - top-5-accuracy: 0.9684 - val_loss: 1.2490 - val_accuracy: 0.6640 - val_top-5-accuracy: 0.9704
Epoch 7/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.2150 - accuracy: 0.6753 - top-5-accuracy: 0.9743 - val_loss: 1.2013 - val_accuracy: 0.6802 - val_top-5-accuracy: 0.9772
+352/352 [==============================] - 6s 16ms/step - loss: 1.2371 - accuracy: 0.6664 - top-5-accuracy: 0.9711 - val_loss: 1.1822 - val_accuracy: 0.6906 - val_top-5-accuracy: 0.9744
Epoch 8/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.1807 - accuracy: 0.6922 - top-5-accuracy: 0.9762 - val_loss: 1.2122 - val_accuracy: 0.6808 - val_top-5-accuracy: 0.9710
+352/352 [==============================] - 6s 16ms/step - loss: 1.1899 - accuracy: 0.6942 - top-5-accuracy: 0.9735 - val_loss: 1.1799 - val_accuracy: 0.6982 - val_top-5-accuracy: 0.9768
Epoch 9/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.1464 - accuracy: 0.7075 - top-5-accuracy: 0.9792 - val_loss: 1.1697 - val_accuracy: 0.6974 - val_top-5-accuracy: 0.9798
+352/352 [==============================] - 6s 16ms/step - loss: 1.1706 - accuracy: 0.6972 - top-5-accuracy: 0.9767 - val_loss: 1.1390 - val_accuracy: 0.7148 - val_top-5-accuracy: 0.9768
Epoch 10/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.1294 - accuracy: 0.7148 - top-5-accuracy: 0.9800 - val_loss: 1.1683 - val_accuracy: 0.6992 - val_top-5-accuracy: 0.9750
+352/352 [==============================] - 6s 16ms/step - loss: 1.1524 - accuracy: 0.7054 - top-5-accuracy: 0.9783 - val_loss: 1.1803 - val_accuracy: 0.7000 - val_top-5-accuracy: 0.9740
Epoch 11/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.1030 - accuracy: 0.7258 - top-5-accuracy: 0.9818 - val_loss: 1.1785 - val_accuracy: 0.6946 - val_top-5-accuracy: 0.9770
+352/352 [==============================] - 6s 16ms/step - loss: 1.1219 - accuracy: 0.7222 - top-5-accuracy: 0.9798 - val_loss: 1.1066 - val_accuracy: 0.7254 - val_top-5-accuracy: 0.9812
Epoch 12/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.0928 - accuracy: 0.7315 - top-5-accuracy: 0.9827 - val_loss: 1.0762 - val_accuracy: 0.7460 - val_top-5-accuracy: 0.9828
+352/352 [==============================] - 6s 16ms/step - loss: 1.1029 - accuracy: 0.7287 - top-5-accuracy: 0.9811 - val_loss: 1.0844 - val_accuracy: 0.7388 - val_top-5-accuracy: 0.9814
Epoch 13/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.0739 - accuracy: 0.7436 - top-5-accuracy: 0.9837 - val_loss: 1.1078 - val_accuracy: 0.7296 - val_top-5-accuracy: 0.9844
+352/352 [==============================] - 6s 16ms/step - loss: 1.0841 - accuracy: 0.7380 - top-5-accuracy: 0.9825 - val_loss: 1.1159 - val_accuracy: 0.7280 - val_top-5-accuracy: 0.9792
Epoch 14/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.0577 - accuracy: 0.7509 - top-5-accuracy: 0.9843 - val_loss: 1.0919 - val_accuracy: 0.7384 - val_top-5-accuracy: 0.9814
+352/352 [==============================] - 6s 16ms/step - loss: 1.0677 - accuracy: 0.7462 - top-5-accuracy: 0.9832 - val_loss: 1.0862 - val_accuracy: 0.7444 - val_top-5-accuracy: 0.9834
Epoch 15/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.0436 - accuracy: 0.7570 - top-5-accuracy: 0.9849 - val_loss: 1.1271 - val_accuracy: 0.7206 - val_top-5-accuracy: 0.9804
+352/352 [==============================] - 6s 16ms/step - loss: 1.0511 - accuracy: 0.7535 - top-5-accuracy: 0.9846 - val_loss: 1.0613 - val_accuracy: 0.7494 - val_top-5-accuracy: 0.9832
Epoch 16/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.0245 - accuracy: 0.7651 - top-5-accuracy: 0.9855 - val_loss: 1.0777 - val_accuracy: 0.7452 - val_top-5-accuracy: 0.9826
+352/352 [==============================] - 6s 16ms/step - loss: 1.0377 - accuracy: 0.7608 - top-5-accuracy: 0.9854 - val_loss: 1.0379 - val_accuracy: 0.7606 - val_top-5-accuracy: 0.9834
Epoch 17/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.0231 - accuracy: 0.7653 - top-5-accuracy: 0.9860 - val_loss: 1.0474 - val_accuracy: 0.7608 - val_top-5-accuracy: 0.9868
+352/352 [==============================] - 6s 16ms/step - loss: 1.0304 - accuracy: 0.7650 - top-5-accuracy: 0.9849 - val_loss: 1.0602 - val_accuracy: 0.7562 - val_top-5-accuracy: 0.9814
Epoch 18/30
-352/352 [==============================] - 5s 15ms/step - loss: 1.0091 - accuracy: 0.7713 - top-5-accuracy: 0.9876 - val_loss: 1.0785 - val_accuracy: 0.7468 - val_top-5-accuracy: 0.9808
+352/352 [==============================] - 6s 16ms/step - loss: 1.0121 - accuracy: 0.7746 - top-5-accuracy: 0.9869 - val_loss: 1.0430 - val_accuracy: 0.7630 - val_top-5-accuracy: 0.9834
Epoch 19/30
-352/352 [==============================] - 5s 15ms/step - loss: 0.9959 - accuracy: 0.7800 - top-5-accuracy: 0.9880 - val_loss: 1.0574 - val_accuracy: 0.7522 - val_top-5-accuracy: 0.9830
+352/352 [==============================] - 6s 16ms/step - loss: 1.0037 - accuracy: 0.7760 - top-5-accuracy: 0.9872 - val_loss: 1.0951 - val_accuracy: 0.7460 - val_top-5-accuracy: 0.9826
Epoch 20/30
-352/352 [==============================] - 5s 16ms/step - loss: 0.9902 - accuracy: 0.7792 - top-5-accuracy: 0.9883 - val_loss: 1.1174 - val_accuracy: 0.7354 - val_top-5-accuracy: 0.9834
+352/352 [==============================] - 6s 16ms/step - loss: 0.9964 - accuracy: 0.7805 - top-5-accuracy: 0.9871 - val_loss: 1.0683 - val_accuracy: 0.7538 - val_top-5-accuracy: 0.9834
Epoch 21/30
-352/352 [==============================] - 5s 15ms/step - loss: 0.9855 - accuracy: 0.7830 - top-5-accuracy: 0.9883 - val_loss: 1.0374 - val_accuracy: 0.7598 - val_top-5-accuracy: 0.9850
+352/352 [==============================] - 6s 16ms/step - loss: 0.9838 - accuracy: 0.7850 - top-5-accuracy: 0.9886 - val_loss: 1.0185 - val_accuracy: 0.7770 - val_top-5-accuracy: 0.9876
Epoch 22/30
-352/352 [==============================] - 5s 15ms/step - loss: 0.9750 - accuracy: 0.7890 - top-5-accuracy: 0.9898 - val_loss: 1.0547 - val_accuracy: 0.7570 - val_top-5-accuracy: 0.9824
+352/352 [==============================] - 6s 16ms/step - loss: 0.9742 - accuracy: 0.7904 - top-5-accuracy: 0.9894 - val_loss: 1.0253 - val_accuracy: 0.7738 - val_top-5-accuracy: 0.9838
Epoch 23/30
-352/352 [==============================] - 5s 15ms/step - loss: 0.9696 - accuracy: 0.7903 - top-5-accuracy: 0.9898 - val_loss: 1.0271 - val_accuracy: 0.7680 - val_top-5-accuracy: 0.9856
+352/352 [==============================] - 6s 16ms/step - loss: 0.9662 - accuracy: 0.7935 - top-5-accuracy: 0.9889 - val_loss: 1.0107 - val_accuracy: 0.7786 - val_top-5-accuracy: 0.9860
Epoch 24/30
-352/352 [==============================] - 5s 15ms/step - loss: 0.9634 - accuracy: 0.7957 - top-5-accuracy: 0.9890 - val_loss: 1.0197 - val_accuracy: 0.7742 - val_top-5-accuracy: 0.9864
+352/352 [==============================] - 6s 16ms/step - loss: 0.9549 - accuracy: 0.7994 - top-5-accuracy: 0.9897 - val_loss: 1.0089 - val_accuracy: 0.7790 - val_top-5-accuracy: 0.9852
Epoch 25/30
-352/352 [==============================] - 5s 15ms/step - loss: 0.9513 - accuracy: 0.8004 - top-5-accuracy: 0.9898 - val_loss: 1.0614 - val_accuracy: 0.7590 - val_top-5-accuracy: 0.9826
+352/352 [==============================] - 6s 16ms/step - loss: 0.9522 - accuracy: 0.8018 - top-5-accuracy: 0.9896 - val_loss: 1.0214 - val_accuracy: 0.7780 - val_top-5-accuracy: 0.9866
Epoch 26/30
-352/352 [==============================] - 5s 15ms/step - loss: 0.9498 - accuracy: 0.8014 - top-5-accuracy: 0.9897 - val_loss: 1.0088 - val_accuracy: 0.7792 - val_top-5-accuracy: 0.9858
+352/352 [==============================] - 6s 16ms/step - loss: 0.9469 - accuracy: 0.8023 - top-5-accuracy: 0.9897 - val_loss: 0.9993 - val_accuracy: 0.7816 - val_top-5-accuracy: 0.9882
Epoch 27/30
-352/352 [==============================] - 5s 15ms/step - loss: 0.9393 - accuracy: 0.8040 - top-5-accuracy: 0.9904 - val_loss: 1.0632 - val_accuracy: 0.7598 - val_top-5-accuracy: 0.9808
+352/352 [==============================] - 6s 16ms/step - loss: 0.9463 - accuracy: 0.8022 - top-5-accuracy: 0.9906 - val_loss: 1.0071 - val_accuracy: 0.7848 - val_top-5-accuracy: 0.9850
Epoch 28/30
-352/352 [==============================] - 5s 15ms/step - loss: 0.9390 - accuracy: 0.8063 - top-5-accuracy: 0.9901 - val_loss: 1.0624 - val_accuracy: 0.7580 - val_top-5-accuracy: 0.9808
+352/352 [==============================] - 6s 16ms/step - loss: 0.9336 - accuracy: 0.8077 - top-5-accuracy: 0.9909 - val_loss: 1.0113 - val_accuracy: 0.7868 - val_top-5-accuracy: 0.9856
Epoch 29/30
-352/352 [==============================] - 5s 15ms/step - loss: 0.9421 - accuracy: 0.8045 - top-5-accuracy: 0.9901 - val_loss: 1.0095 - val_accuracy: 0.7768 - val_top-5-accuracy: 0.9870
+352/352 [==============================] - 6s 16ms/step - loss: 0.9352 - accuracy: 0.8071 - top-5-accuracy: 0.9909 - val_loss: 1.0073 - val_accuracy: 0.7856 - val_top-5-accuracy: 0.9830
Epoch 30/30
-352/352 [==============================] - 5s 15ms/step - loss: 0.9234 - accuracy: 0.8108 - top-5-accuracy: 0.9915 - val_loss: 1.0183 - val_accuracy: 0.7808 - val_top-5-accuracy: 0.9838
-313/313 [==============================] - 2s 5ms/step - loss: 1.0569 - accuracy: 0.7645 - top-5-accuracy: 0.9827
-Test accuracy: 76.45%
-Test top 5 accuracy: 98.27%
+352/352 [==============================] - 6s 16ms/step - loss: 0.9273 - accuracy: 0.8112 - top-5-accuracy: 0.9908 - val_loss: 1.0144 - val_accuracy: 0.7792 - val_top-5-accuracy: 0.9836
+313/313 [==============================] - 2s 6ms/step - loss: 1.0396 - accuracy: 0.7676 - top-5-accuracy: 0.9839
+Test accuracy: 76.76%
+Test top 5 accuracy: 98.39%

```
</div>
@@ -475,14 +476,15 @@ plt.show()

The CCT model we just trained has just **0.4 million** parameters, and it gets us to
~78% top-1 accuracy within 30 epochs. The plot above shows no signs of overfitting as
-well. This means we can train this network for longers (perhaps with a bit more
+well. This means we can train this network for longer (perhaps with a bit more
regularization) and may obtain even better performance. This performance can further be
improved by additional recipes like cosine decay learning rate schedule, other data augmentation
techniques like [AutoAugment](https://arxiv.org/abs/1805.09501),
[MixUp](https://arxiv.org/abs/1710.09412) or
-[Cutmix](https://arxiv.org/abs/1905.04899. The authors also present a number of
+[Cutmix](https://arxiv.org/abs/1905.04899). With these modifications, the authors present
+95.1% top-1 accuracy on the CIFAR-10 dataset. The authors also present a number of
experiments to study how the number of convolution blocks, Transformers layers, etc.
-affect the final performance.
+affect the final performance of CCTs.

For a comparison, a ViT model takes about **4.7 million** parameters and **100
epochs** of training to reach a top-1 accuracy of 78.22% on the CIFAR-10 dataset. You can
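Of the augmentation recipes named in the amended paragraph, MixUp is the simplest to sketch. A minimal, self-contained version for batches of images with one-hot labels (an illustration only, not this example's code; `alpha=0.2` and the per-batch partner-shuffling strategy are assumptions):

```python
import tensorflow as tf

def mixup(images, labels, alpha=0.2):
    """Blend each example with a shuffled partner from the same batch."""
    batch_size = tf.shape(images)[0]
    # Draw lambda ~ Beta(alpha, alpha) from two Gamma samples:
    # if G1 ~ Gamma(a) and G2 ~ Gamma(b), then G1 / (G1 + G2) ~ Beta(a, b).
    g1 = tf.random.gamma([batch_size], alpha)
    g2 = tf.random.gamma([batch_size], alpha)
    lam = g1 / (g1 + g2)
    lam_img = tf.reshape(lam, [-1, 1, 1, 1])  # broadcast over H, W, C
    lam_lab = tf.reshape(lam, [-1, 1])        # broadcast over classes
    # Pair every example with a random other example in the batch.
    idx = tf.random.shuffle(tf.range(batch_size))
    mixed_images = lam_img * images + (1.0 - lam_img) * tf.gather(images, idx)
    mixed_labels = lam_lab * labels + (1.0 - lam_lab) * tf.gather(labels, idx)
    return mixed_images, mixed_labels

# Usage sketch: map over a tf.data pipeline of (image, one-hot label) batches.
# train_ds = train_ds.map(mixup, num_parallel_calls=tf.data.AUTOTUNE)
```

Because the mixed labels are soft, the loss should be a non-sparse categorical cross-entropy, which also pairs naturally with label smoothing.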