Skip to content

Commit 123307f

Browse files
committed
update
1 parent e6661e2 commit 123307f

File tree

5 files changed

+12
-282
lines changed

5 files changed

+12
-282
lines changed

cnn_architecture

Lines changed: 0 additions & 240 deletions
This file was deleted.

cnn_architecture.png

-26.2 KB
Binary file not shown.

notebooks/EDA_and_Explainability.ipynb

Lines changed: 5 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
},
1111
{
1212
"cell_type": "code",
13-
"execution_count": 1,
13+
"execution_count": null,
1414
"id": "4534a191",
1515
"metadata": {},
1616
"outputs": [
@@ -46,7 +46,6 @@
4646
"\n",
4747
"# --- Visualize some samples ---\n",
4848
"%matplotlib inline\n",
49-
"\n",
5049
"fig, axes = plt.subplots(2, 5, figsize=(6, 4))\n",
5150
"\n",
5251
"for ax in axes.flat:\n",
@@ -263,7 +262,7 @@
263262
},
264263
{
265264
"cell_type": "code",
266-
"execution_count": 6,
265+
"execution_count": null,
267266
"id": "4a1f7733",
268267
"metadata": {},
269268
"outputs": [
@@ -308,7 +307,7 @@
308307
"from src.models.cnn_model import CNNModel\n",
309308
"\n",
310309
"model = CNNModel(num_classes=10)\n",
311-
"summary(model, input_size=(1, 1, 28, 28))\n"
310+
"summary(model, input_size=(1, 1, 28, 28))"
312311
]
313312
},
314313
{
@@ -321,7 +320,7 @@
321320
},
322321
{
323322
"cell_type": "code",
324-
"execution_count": 7,
323+
"execution_count": null,
325324
"id": "dcfe4491",
326325
"metadata": {},
327326
"outputs": [
@@ -353,41 +352,7 @@
353352
" class_names=class_names,\n",
354353
" max_samples=10,\n",
355354
" layer_name=\"conv3\"\n",
356-
")\n"
357-
]
358-
},
359-
{
360-
"cell_type": "code",
361-
"execution_count": 8,
362-
"id": "37ec9dab",
363-
"metadata": {},
364-
"outputs": [
365-
{
366-
"name": "stderr",
367-
"output_type": "stream",
368-
"text": [
369-
"\n",
370-
"(process:23160): Pango-WARNING **: 18:20:07.026: couldn't load font \"Linux libertine Not-Rotated 10\", falling back to \"Sans Not-Rotated 10\", expect ugly output.\n"
371-
]
372-
},
373-
{
374-
"data": {
375-
"text/plain": [
376-
"'cnn_architecture.png'"
377-
]
378-
},
379-
"execution_count": 8,
380-
"metadata": {},
381-
"output_type": "execute_result"
382-
}
383-
],
384-
"source": [
385-
"from torchview import draw_graph\n",
386-
"from src.models.cnn_model import CNNModel\n",
387-
"\n",
388-
"model = CNNModel(num_classes=10)\n",
389-
"graph = draw_graph(model, input_size=(1, 1, 28, 28))\n",
390-
"graph.visual_graph.render(\"cnn_architecture\", format=\"png\")\n"
355+
")"
391356
]
392357
}
393358
],

requirements.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
torch
22
torchvision
33
matplotlib
4+
seaborn
45
numpy
5-
tqdm
6+
scikit-learn
7+
PyYAML
8+
opencv-python
9+
torchviz
10+
torchview

src/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def main():
3535
)
3636

3737
os.makedirs(config["train"]["checkpoint_dir"], exist_ok=True)
38-
best_path = os.path.join(config["train"]["checkpoint_dir"], "best_model.pth")
38+
# best_path = os.path.join(config["train"]["checkpoint_dir"], "best_model.pth")
3939

4040
train_model(
4141
model,

0 commit comments

Comments
 (0)