Skip to content

Commit 71d2079

Browse files
committed
embetter mobilenet
1 parent 2eeedbd commit 71d2079

File tree

5 files changed

+2163
-0
lines changed

5 files changed

+2163
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ A collection of various deep learning architectures, models, and tips for Tensor
9191
| MobileNet-v2 on Cifar-10 | TBD | TBD | [![PyTorch Lightning](https://img.shields.io/badge/PyTorch-Lightning-blueviolet)](pytorch-lightning_ipynb/cnn/cnn-mobilenet-v2-cifar10.ipynb) [![PyTorch](https://img.shields.io/badge/Py-Torch-red)](pytorch_ipynb/cnn/cnn-mobilenet-v2-cifar10.ipynb) |
9292
| MobileNet-v3 small on Cifar-10 | TBD | TBD | [![PyTorch Lightning](https://img.shields.io/badge/PyTorch-Lightning-blueviolet)](pytorch-lightning_ipynb/cnn/cnn-mobilenet-v3-small-cifar10.ipynb) [![PyTorch](https://img.shields.io/badge/Py-Torch-red)](pytorch_ipynb/cnn/cnn-mobilenet-v3-small-cifar10.ipynb) |
9393
| MobileNet-v3 large on Cifar-10 | TBD | TBD | [![PyTorch Lightning](https://img.shields.io/badge/PyTorch-Lightning-blueviolet)](pytorch-lightning_ipynb/cnn/cnn-mobilenet-v3-large-cifar10.ipynb) [![PyTorch](https://img.shields.io/badge/Py-Torch-red)](pytorch_ipynb/cnn/cnn-mobilenet-v3-large-cifar10.ipynb) |
94+
| MobileNet-v3 large on MNIST via Embetter | TBD | TBD | [![PyTorch](https://img.shields.io/badge/Py-Torch-red)](pytorch_ipynb/cnn/cnn-embetter-mobilenet.ipynb) |
9495

9596

9697

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "4936d1e6-5e7d-4e22-ae35-8e888927ce2d",
6+
"metadata": {},
7+
"source": [
8+
"# Use Pre-trained CNN as feature extractor"
9+
]
10+
},
11+
{
12+
"cell_type": "markdown",
13+
"id": "bf9e9fb5-7383-475a-93e1-decdbd59c247",
14+
"metadata": {},
15+
"source": [
16+
"Use MobileNetv3 as a feature extractor via the [embetter](https://github.com/koaning/embetter) scikit-learn library and [timm](https://github.com/rwightman/pytorch-image-models). Train a logistic regression classifier in scikit-learn on the embeddings."
17+
]
18+
},
19+
{
20+
"cell_type": "markdown",
21+
"id": "96b717c7-54c9-40dc-ba80-0fb47da2c0bd",
22+
"metadata": {},
23+
"source": [
24+
"![](images/feature-extractor.png)"
25+
]
26+
},
27+
{
28+
"cell_type": "code",
29+
"execution_count": 1,
30+
"id": "64d1dd64-c45b-4092-84d1-1bfcd0998f15",
31+
"metadata": {},
32+
"outputs": [],
33+
"source": [
34+
"import os\n",
35+
"\n",
36+
"# pip install gitpython\n",
37+
"from git import Repo\n",
38+
"\n",
39+
"if not os.path.exists(\"mnist-pngs\"):\n",
40+
" Repo.clone_from(\"https://github.com/rasbt/mnist-pngs\", \"mnist-pngs\")"
41+
]
42+
},
43+
{
44+
"cell_type": "code",
45+
"execution_count": 2,
46+
"id": "3a892538-8d9b-4420-9525-26d1a4b37ae3",
47+
"metadata": {},
48+
"outputs": [],
49+
"source": [
50+
"import os\n",
51+
"import pandas as pd\n",
52+
"\n",
53+
"for name in (\"train\", \"test\"):\n",
54+
"\n",
55+
" df = pd.read_csv(f\"mnist-pngs/{name}.csv\")\n",
56+
" df[\"filepath\"] = df[\"filepath\"].apply(lambda x: \"mnist-pngs/\" + x)\n",
57+
" df = df.sample(frac=1, random_state=123).reset_index(drop=True)\n",
58+
" df.to_csv(f\"mnist-pngs/{name}_shuffled.csv\", index=None)"
59+
]
60+
},
61+
{
62+
"cell_type": "code",
63+
"execution_count": 3,
64+
"id": "5885e9bb-d43f-46ca-83ae-e2d63edcbb37",
65+
"metadata": {},
66+
"outputs": [
67+
{
68+
"data": {
69+
"application/vnd.jupyter.widget-view+json": {
70+
"model_id": "1fba0fcb2b1f408f85013da0d1694dd3",
71+
"version_major": 2,
72+
"version_minor": 0
73+
},
74+
"text/plain": [
75+
" 0%| | 0/60 [00:00<?, ?it/s]"
76+
]
77+
},
78+
"metadata": {},
79+
"output_type": "display_data"
80+
}
81+
],
82+
"source": [
83+
"from sklearn.pipeline import make_pipeline\n",
84+
"from sklearn.linear_model import SGDClassifier\n",
85+
"from tqdm.notebook import tqdm\n",
86+
"\n",
87+
"# pip install \"embetter[vision]\"\n",
88+
"from embetter.vision import ImageLoader, TimmEncoder\n",
89+
"\n",
90+
"\n",
91+
"embed = make_pipeline(\n",
92+
" ImageLoader(),\n",
93+
" TimmEncoder(name=\"mobilenetv3_large_100\")\n",
94+
")\n",
95+
"\n",
96+
"model = SGDClassifier(loss='log_loss', n_jobs=-1, shuffle=True)\n",
97+
"\n",
98+
"chunksize = 1000\n",
99+
"train_labels, train_predict = [], []\n",
100+
"\n",
101+
"for df in tqdm(pd.read_csv(\"mnist-pngs/train_shuffled.csv\", chunksize=chunksize, iterator=True), total=60):\n",
102+
" \n",
103+
" embedded = embed.transform(df[\"filepath\"])\n",
104+
" model.partial_fit(embedded, df[\"label\"], classes=list(range(10)))"
105+
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": 4,
110+
"id": "999a24ea-be5d-425f-923c-266372c66b5d",
111+
"metadata": {},
112+
"outputs": [
113+
{
114+
"data": {
115+
"application/vnd.jupyter.widget-view+json": {
116+
"model_id": "157302965ac8460c97c77935cc08e1fc",
117+
"version_major": 2,
118+
"version_minor": 0
119+
},
120+
"text/plain": [
121+
" 0%| | 0/60 [00:00<?, ?it/s]"
122+
]
123+
},
124+
"metadata": {},
125+
"output_type": "display_data"
126+
}
127+
],
128+
"source": [
129+
"train_labels, train_predict = [], []\n",
130+
"\n",
131+
"for df in tqdm(pd.read_csv(\"mnist-pngs/train.csv\", chunksize=chunksize, iterator=True), total=60):\n",
132+
" df[\"filepath\"] = df[\"filepath\"].apply(lambda x: \"mnist-pngs/\" + x)\n",
133+
"\n",
134+
" embedded = embed.transform(df[\"filepath\"])\n",
135+
" train_predict.extend(model.predict(embedded))\n",
136+
" train_labels.extend(list(df[\"label\"].values))"
137+
]
138+
},
139+
{
140+
"cell_type": "code",
141+
"execution_count": 5,
142+
"id": "c816cd7b-ed3a-4cb2-8aa6-400068a2e414",
143+
"metadata": {},
144+
"outputs": [
145+
{
146+
"data": {
147+
"application/vnd.jupyter.widget-view+json": {
148+
"model_id": "7869826407314279a2806bf602a796a8",
149+
"version_major": 2,
150+
"version_minor": 0
151+
},
152+
"text/plain": [
153+
" 0%| | 0/10 [00:00<?, ?it/s]"
154+
]
155+
},
156+
"metadata": {},
157+
"output_type": "display_data"
158+
}
159+
],
160+
"source": [
161+
"test_labels, test_predict = [], []\n",
162+
"\n",
163+
"for df in tqdm(pd.read_csv(\"mnist-pngs/test_shuffled.csv\", chunksize=chunksize, iterator=True), total=10):\n",
164+
"\n",
165+
" embedded = embed.transform(df[\"filepath\"])\n",
166+
" test_predict.extend(model.predict(embedded))\n",
167+
" test_labels.extend(list(df[\"label\"].values))"
168+
]
169+
},
170+
{
171+
"cell_type": "code",
172+
"execution_count": 6,
173+
"id": "4a78add1-7f93-40fc-b119-9dbbe0aa55b4",
174+
"metadata": {},
175+
"outputs": [
176+
{
177+
"name": "stdout",
178+
"output_type": "stream",
179+
"text": [
180+
"Train accuracy: 0.92\n",
181+
"Test accuracy: 0.92\n"
182+
]
183+
}
184+
],
185+
"source": [
186+
"from sklearn.metrics import accuracy_score\n",
187+
"\n",
188+
"print(f\"Train accuracy: {accuracy_score(train_labels, train_predict):.2f}\")\n",
189+
"print(f\"Test accuracy: {accuracy_score(test_labels, test_predict):.2f}\")"
190+
]
191+
}
192+
],
193+
"metadata": {
194+
"kernelspec": {
195+
"display_name": "Python 3 (ipykernel)",
196+
"language": "python",
197+
"name": "python3"
198+
},
199+
"language_info": {
200+
"codemirror_mode": {
201+
"name": "ipython",
202+
"version": 3
203+
},
204+
"file_extension": ".py",
205+
"mimetype": "text/x-python",
206+
"name": "python",
207+
"nbconvert_exporter": "python",
208+
"pygments_lexer": "ipython3",
209+
"version": "3.9.7"
210+
}
211+
},
212+
"nbformat": 4,
213+
"nbformat_minor": 5
214+
}
181 KB
Loading

0 commit comments

Comments
 (0)