Skip to content

Commit

Permalink
tf_efficientnet_b4_ns inference
Browse files Browse the repository at this point in the history
  • Loading branch information
IMOKURI committed Jan 15, 2021
1 parent 9a42d4b commit e10c0bd
Showing 1 changed file with 20 additions and 84 deletions.
104 changes: 20 additions & 84 deletions cassava-resnext50-32x4d-starter-inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
"metadata": {},
"source": [
"# About this notebook \n",
"- PyTorch resnext50_32x4d starter code \n",
"- StratifiedKFold 5 folds \n",
"- training code is [here](https://www.kaggle.com/yasufuminakama/cassava-resnext50-32x4d-starter-training)\n",
"\n",
"If this notebook is helpful, feel free to upvote :)"
"TBD...\n"
]
},
{
Expand Down Expand Up @@ -72,15 +69,15 @@
"class CFG:\n",
" debug = False\n",
" num_workers = 4\n",
" model_name = \"resnext50_32x4d\"\n",
" model_name = \"tf_efficientnet_b4_ns\" # resnext50_32x4d, tf_efficientnet_b3_ns, tf_efficientnet_b4_ns\n",
" size = 512 # 512 if ON_KAGGLE else 384\n",
" batch_size = 14\n",
" batch_size = 8 # resnext50_32x4d: 14, tf_efficientnet_b3_ns:10, tf_efficientnet_b4_ns: 8\n",
" seed = 22\n",
" target_size = 5\n",
" target_col = \"label\"\n",
" n_fold = 5\n",
" trn_fold = [0, 1, 2, 3, 4]\n",
" tta = 10 # 1: no TTA, >1: TTA\n",
" tta = 1 # 1: no TTA, >1: TTA\n",
" train = False\n",
" inference = True"
]
Expand Down Expand Up @@ -377,8 +374,8 @@
" HorizontalFlip(p=0.5),\n",
" VerticalFlip(p=0.5),\n",
" # ShiftScaleRotate(p=0.5),\n",
" # HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),\n",
" # RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),\n",
" HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),\n",
" RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),\n",
" # CoarseDropout(p=0.5),\n",
" # Cutout(p=0.5),\n",
" Normalize(\n",
Expand Down Expand Up @@ -406,12 +403,18 @@
"# ====================================================\n",
"# MODEL\n",
"# ====================================================\n",
"class CustomResNext(nn.Module):\n",
"class CassvaImgClassifier(nn.Module):\n",
" def __init__(self, model_name=\"resnext50_32x4d\", pretrained=False):\n",
" super().__init__()\n",
" self.model = timm.create_model(model_name, pretrained=pretrained)\n",
" n_features = self.model.fc.in_features\n",
" self.model.fc = nn.Linear(n_features, CFG.target_size)\n",
"\n",
" if model_name == \"resnext50_32x4d\":\n",
" n_features = self.model.fc.in_features\n",
" self.model.fc = nn.Linear(n_features, CFG.target_size)\n",
"\n",
" elif model_name.startswith(\"tf_efficientnet\"):\n",
" n_features = self.model.classifier.in_features\n",
" self.model.classifier = nn.Linear(n_features, CFG.target_size)\n",
"\n",
" def forward(self, x):\n",
" x = self.model(x)\n",
Expand Down Expand Up @@ -468,7 +471,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand All @@ -481,65 +484,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fd967ee248e0443a99fd8273e91328ec",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Inference(TTA) example: [0.02446977 0.0249227 0.2398808 0.02260921 0.6881176 ]\n",
"========== TTA: 1 ==========\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b7b0e2809d14466a9966d45d63b5d428",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Inference(TTA) example: [0.03281627 0.0269058 0.51045144 0.03546704 0.3943594 ]\n",
"========== TTA: 2 ==========\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "929cb6d2c5004f2692bec58711bd536e",
"model_id": "438824d275cd4f4aade6b2aff2769fa5",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -550,15 +495,6 @@
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Inference(TTA) example: [0.03716791 0.03428895 0.31692576 0.04033972 0.57127774]\n",
"========== Overall ==========\n",
"Inference(overall) example: [0.03148465 0.02870581 0.35575268 0.03280532 0.5512516 ]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
Expand Down Expand Up @@ -606,7 +542,7 @@
"0 2216849948.jpg 4"
]
},
"execution_count": 12,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -617,10 +553,10 @@
"# ====================================================\n",
"for i in range(CFG.tta):\n",
" LOGGER.info(f\"========== TTA: {i} ==========\")\n",
" model = CustomResNext(CFG.model_name, pretrained=False)\n",
" model = CassvaImgClassifier(CFG.model_name, pretrained=False)\n",
" states = [torch.load(MODEL_DIR + f\"{CFG.model_name}_fold{fold}_best.pth\") for fold in CFG.trn_fold]\n",
"\n",
" if CFG.tta == 1: # no TTA\n",
" if i == 0: # no TTA\n",
" test_dataset = TestDataset(test, transform=get_transforms(data=\"valid\"))\n",
" else:\n",
" test_dataset = TestDataset(test, transform=get_transforms(data=\"inference\"))\n",
Expand Down

0 comments on commit e10c0bd

Please sign in to comment.