Skip to content

Commit

Permalink
TTA x10
Browse files Browse the repository at this point in the history
  • Loading branch information
IMOKURI committed Jan 13, 2021
1 parent c8175bb commit 8380374
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
| same above | [v3-TTAx10] | resnext50_32x4d | 0.893 | same above | TTA x10 mean() |
| same above | [v3-TTAx15] | resnext50_32x4d | 0.894 | same above | TTA x15 mean() |
| [v4-train] | [v4-inf] | resnext50_32x4d | 0.895 | 0.88822 | add augmentation functions, no TTA |
| same above | [v4-TTAx3] | resnext50_32x4d | 0.893 | same above | TTA x3 mean() |

## Memo

Expand All @@ -32,14 +33,19 @@
### Model

- [x] Resnext50_32x4d
- [ ] EfficientNet
- [ ] EfficientNet B3, B4 with Noisy Student
- vision transformer

### Loss

- [x] [Bi-Tempered Logistic Loss](https://www.kaggle.com/c/cassava-leaf-disease-classification/discussion/202017)
- [x] [with label smoothing](https://www.kaggle.com/piantic/train-cassava-starter-using-various-loss-funcs/notebook#Bi-Tempered-Loss)

### Training

- [batch normalization layers frozen for EfficientNet](https://keras.io/examples/vision/image_classification_efficientnet_fine_tuning/#tips-for-fine-tuning-efficientnet)
- early stopping

### Inference

- [x] [TTA(Test Time Augmentation)](https://www.kaggle.com/khyeh0719/pytorch-efficientnet-baseline-inference-tta)
Expand All @@ -48,6 +54,7 @@

- [A few things for easy start](https://www.kaggle.com/c/cassava-leaf-disease-classification/discussion/207450)
- [Sharing some improvements and experiments](https://www.kaggle.com/c/cassava-leaf-disease-classification/discussion/203594)
- [Important points to boost the LB score](https://www.kaggle.com/c/cassava-leaf-disease-classification/discussion/208402)


[v1-train]: https://github.com/IMOKURI/Cassava-Leaf-Disease-Classification/commit/59a171a0e4ee6c8d7f87a3e9248333506a466405
Expand All @@ -60,3 +67,4 @@
[v3-TTAx15]: https://github.com/IMOKURI/Cassava-Leaf-Disease-Classification/commit/7297aecb96fc1630178344702f5466c50bd1c836
[v4-train]: https://github.com/IMOKURI/Cassava-Leaf-Disease-Classification/commit/c88d247a84fd424d58403437888346e458466a1c
[v4-inf]: https://github.com/IMOKURI/Cassava-Leaf-Disease-Classification/commit/da37e635677cefd6df64f5ff38d286f336af7b92
[v4-TTAx3]: https://github.com/IMOKURI/Cassava-Leaf-Disease-Classification/commit/af68da580b9ab946e423da2199adb95c8956ca43
9 changes: 7 additions & 2 deletions cassava-resnext50-32x4d-starter-inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
" target_col = \"label\"\n",
" n_fold = 5\n",
" trn_fold = [0, 1, 2, 3, 4]\n",
" tta = 3\n",
" tta = 10 # 1: no TTA, >1: TTA\n",
" train = False\n",
" inference = True"
]
Expand Down Expand Up @@ -619,7 +619,12 @@
" LOGGER.info(f\"========== TTA: {i} ==========\")\n",
" model = CustomResNext(CFG.model_name, pretrained=False)\n",
" states = [torch.load(MODEL_DIR + f\"{CFG.model_name}_fold{fold}_best.pth\") for fold in CFG.trn_fold]\n",
" test_dataset = TestDataset(test, transform=get_transforms(data=\"inference\"))\n",
"\n",
" if CFG.tta == 1: # no TTA\n",
" test_dataset = TestDataset(test, transform=get_transforms(data=\"valid\"))\n",
" else:\n",
" test_dataset = TestDataset(test, transform=get_transforms(data=\"inference\"))\n",
"\n",
" test_loader = DataLoader(\n",
" test_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers, pin_memory=True\n",
" )\n",
Expand Down

0 comments on commit 8380374

Please sign in to comment.