Skip to content

Commit 452de7b

Browse files
feat: create run file
1 parent d4cd095 commit 452de7b

File tree

3 files changed

+41
-0
lines changed

3 files changed

+41
-0
lines changed

data/train/__delete_me__

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Delete this file after cloning the repository

data/valid/__delete_me__

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Delete this file after cloning the repository

run.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import os
2+
import torch
3+
4+
from models.autoencoder import Autoencoder
5+
from utils.dataloader import get_dataloader
6+
from utils.trainer import train_autoencoder, visualize_reconstructions, save_model, load_model, evaluate_autoencoder
7+
from settings import settings
8+
9+
10+
def main(load_trained_model):
11+
BATCH_SIZE = 32
12+
INPUT_DIM = 3 * 64 * 64
13+
ENCODING_DIM = 12
14+
NUM_EPOCHS = 1000
15+
16+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17+
18+
dataloader = get_dataloader(settings.DATA_PATH, BATCH_SIZE)
19+
model = Autoencoder(INPUT_DIM, ENCODING_DIM).to(device)
20+
21+
if load_trained_model:
22+
trained_model = load_model(model, settings.PATH_SAVED_MODEL, device=device)
23+
else:
24+
trained_model = train_autoencoder(model, dataloader, NUM_EPOCHS, device=device)
25+
26+
valid_dataloader = get_dataloader(settings.VALID_DATA_PATH, BATCH_SIZE)
27+
28+
save_path = os.path.join('./', settings.PATH_SAVED_MODEL)
29+
save_model(trained_model, save_path)
30+
print(f"Model saved to {save_path}")
31+
32+
avg_valid_loss = evaluate_autoencoder(trained_model, valid_dataloader, device)
33+
print(f"Average validation loss: {avg_valid_loss:.4f}")
34+
35+
visualize_reconstructions(trained_model, valid_dataloader, num_samples=10, device=device)
36+
37+
38+
if __name__ == "__main__":
39+
main(False)

0 commit comments

Comments
 (0)