Skip to content

Commit 2880ca9

Browse files
Merge pull request #3 from renan-siqueira/feature/ProjectFeatures
Project features
2 parents 02a6911 + 452de7b commit 2880ca9

File tree

7 files changed

+158
-0
lines changed

7 files changed

+158
-0
lines changed

data/train/__delete_me__

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Delete this file after cloning the repository

data/valid/__delete_me__

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Delete this file after cloning the repository

run.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import os
2+
import torch
3+
4+
from models.autoencoder import Autoencoder
5+
from utils.dataloader import get_dataloader
6+
from utils.trainer import train_autoencoder, visualize_reconstructions, save_model, load_model, evaluate_autoencoder
7+
from settings import settings
8+
9+
10+
def main(load_trained_model):
    """Train or load the autoencoder, then evaluate and visualize it.

    Args:
        load_trained_model: When truthy, restore the weights stored at
            ``settings.PATH_SAVED_MODEL`` instead of training. Otherwise the
            model is trained on the images in ``settings.DATA_PATH`` and the
            resulting weights are written to ``settings.PATH_SAVED_MODEL``.

    In both cases the model is then scored on ``settings.VALID_DATA_PATH``
    and a grid of reconstructions is shown and saved.
    """
    BATCH_SIZE = 32
    # Flattened size of one sample; matches the 64x64 RGB images produced by
    # utils.dataloader once the trainer flattens them.
    INPUT_DIM = 3 * 64 * 64
    ENCODING_DIM = 12
    NUM_EPOCHS = 1000

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Autoencoder(INPUT_DIM, ENCODING_DIM).to(device)

    if load_trained_model:
        trained_model = load_model(model, settings.PATH_SAVED_MODEL, device=device)
    else:
        # The training data is only read when we actually train, so a
        # load-only run does not require the training directory.
        dataloader = get_dataloader(settings.DATA_PATH, BATCH_SIZE)
        trained_model = train_autoencoder(model, dataloader, NUM_EPOCHS, device=device)

        # Persist only freshly trained weights, to the same path load_model
        # reads from; re-saving a checkpoint that was just loaded is a no-op.
        save_path = settings.PATH_SAVED_MODEL
        save_model(trained_model, save_path)
        print(f"Model saved to {save_path}")

    valid_dataloader = get_dataloader(settings.VALID_DATA_PATH, BATCH_SIZE)

    avg_valid_loss = evaluate_autoencoder(trained_model, valid_dataloader, device)
    print(f"Average validation loss: {avg_valid_loss:.4f}")

    visualize_reconstructions(trained_model, valid_dataloader, num_samples=10, device=device)


if __name__ == "__main__":
    main(False)

settings/settings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Directory that run.py reads the training images from.
DATA_PATH = './data/train'
2+
# Directory that run.py reads the validation images from.
VALID_DATA_PATH = './data/valid'
3+
# File the model weights are saved to and loaded from by run.py.
PATH_SAVED_MODEL = './autoencoder_model.pth'

utils/__init__.py

Whitespace-only changes.

utils/dataloader.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import os
2+
import torch
3+
from torchvision import datasets, transforms
4+
from torchvision.transforms import ToTensor, Resize, Compose
5+
from torch.utils.data import DataLoader, Dataset
6+
from PIL import Image
7+
8+
9+
def get_dataloader(data_path, batch_size):
    """Build a shuffling ``DataLoader`` over the files in ``data_path``.

    Args:
        data_path: Directory handed to ``CustomDataset``.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` wrapping a ``CustomDataset`` with ``shuffle=True``.
    """
    return DataLoader(CustomDataset(data_path), batch_size=batch_size, shuffle=True)
19+
20+
21+
class CustomDataset(Dataset):
    """Dataset of the image files found directly inside one directory.

    Each item is the image converted to RGB, resized to 64x64 and passed
    through ``ToTensor``. There are no labels; ``__getitem__`` returns only
    the transformed image.
    """

    # Lower-case file extensions that are treated as images. Anything else in
    # the directory (placeholder files, notes, sub-directories) is ignored
    # instead of crashing in ``Image.open`` during iteration.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff', '.webp')

    def __init__(self, data_path):
        """Index the image files in ``data_path`` and set up the transforms.

        Args:
            data_path: Directory whose image files make up the dataset.

        Raises:
            OSError: If ``data_path`` cannot be listed.
        """
        self.data_path = data_path
        self.image_files = self._list_images(data_path)

        self.transforms = Compose([
            Resize((64, 64)),
            ToTensor()
        ])

    @classmethod
    def _list_images(cls, data_path):
        """Return the sorted names of the image files directly in ``data_path``."""
        # Sorting makes the index -> file mapping reproducible; os.listdir
        # returns entries in arbitrary order.
        return sorted(
            name
            for name in os.listdir(data_path)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(data_path, name))
        )

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, convert and transform the image at position ``idx``."""
        image_path = os.path.join(self.data_path, self.image_files[idx])
        # Image.open is lazy; convert() reads the pixel data, after which the
        # context manager releases the underlying file handle.
        with Image.open(image_path) as raw_image:
            rgb_image = raw_image.convert('RGB')
        return self.transforms(rgb_image)

utils/trainer.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import os
2+
import torch
3+
import torch.optim as optim
4+
import torch.nn as nn
5+
from torchvision import transforms
6+
from torchvision.utils import save_image, make_grid
7+
import matplotlib.pyplot as plt
8+
from PIL import Image
9+
10+
11+
def train_autoencoder(model, dataloader, num_epochs=5, learning_rate=0.001, device='cpu'):
    """Train ``model`` to reconstruct its input using Adam and MSE loss.

    Args:
        model: Module that maps a flattened batch to an output of the same
            shape. It is updated in place.
        dataloader: Iterable yielding batches of tensors. Each batch is moved
            to ``device`` and flattened to ``(batch, -1)`` before the forward
            pass.
        num_epochs: Number of passes over ``dataloader``.
        learning_rate: Learning rate passed to ``optim.Adam``.
        device: Device the batches are moved to. The model is not moved here.

    Returns:
        The same ``model`` instance after training.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Other helpers in this module switch the model to eval mode; make sure
    # training-mode behaviour is active again before optimising.
    model.train()

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        for data in dataloader:
            img = data.to(device)
            img = img.view(img.size(0), -1)
            output = model(img)
            loss = criterion(output, img)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("dataloader yielded no batches; nothing to train on")

        # Report the mean batch loss of the epoch rather than only the last
        # batch, which is noisy and not comparable with the evaluation loss.
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / num_batches:.4f}')

    return model
29+
30+
31+
def visualize_reconstructions(model, dataloader, num_samples=10, device='cpu', save_path="./samples"):
    """Show and save a grid of input images above their reconstructions.

    The first batch of ``dataloader`` is truncated to at most ``num_samples``
    images, passed through ``model`` in eval mode, and originals plus
    reconstructions are arranged in one grid (originals in the first row).
    The grid is written to ``<save_path>/combined_samples.png`` and then
    displayed with Matplotlib.

    Args:
        model: Module mapping a flattened batch to an output of the same shape.
        dataloader: Iterable whose first batch supplies the images; they are
            reshaped to ``(-1, 3, 64, 64)`` for display.
        num_samples: Maximum number of images taken from the first batch.
        device: Device the samples are moved to.
        save_path: Directory the grid image is written to; created if missing.
    """
    model.eval()
    samples = next(iter(dataloader))
    samples = samples[:num_samples].to(device)
    samples = samples.view(samples.size(0), -1)

    # Inference only: no autograd graph is needed for visualisation.
    with torch.no_grad():
        reconstructions = model(samples)

    samples = samples.view(-1, 3, 64, 64)
    reconstructions = reconstructions.view(-1, 3, 64, 64)

    # Combine the samples and the reconstructions into a single grid. Using
    # the actual sample count as the row width keeps every reconstruction
    # directly below its original even when the batch is smaller than
    # num_samples.
    combined = torch.cat([samples, reconstructions], dim=0)
    grid_img = make_grid(combined, nrow=samples.size(0))

    # Save before showing: plt.show() may block, or fail without a display,
    # and the file should be written regardless.
    os.makedirs(save_path, exist_ok=True)
    save_image(grid_img, os.path.join(save_path, 'combined_samples.png'))

    # Visualisation using Matplotlib (expects height x width x channels).
    plt.imshow(grid_img.permute(1, 2, 0).cpu().detach().numpy())
    plt.axis('off')
    plt.show()
53+
54+
55+
def save_model(model, path):
    """Write ``model.state_dict()`` to ``path`` with ``torch.save``.

    Args:
        model: Module whose ``state_dict`` is saved.
        path: Destination file path, or any other target ``torch.save``
            accepts (for example a binary file object). For path-like
            targets, missing parent directories are created first.
    """
    # Only filesystem paths have a parent directory to create; file objects
    # are handed to torch.save untouched.
    if isinstance(path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    torch.save(model.state_dict(), path)
57+
58+
59+
def load_model(model, path, device='cpu'):
    """Load a saved ``state_dict`` into ``model`` and switch it to eval mode.

    Args:
        model: Module whose parameters are overwritten; its architecture must
            match the saved ``state_dict``.
        path: File previously written with ``torch.save(model.state_dict(), ...)``.
        device: ``map_location`` used while reading the file. Defaults to
            ``'cpu'`` like the other helpers in this module.

    Returns:
        The same ``model`` instance, in eval mode.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources; consider weights_only=True where the installed
    # torch version supports it.
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model
63+
64+
65+
def evaluate_autoencoder(model, dataloader, device='cpu'):
    """Return the average MSE reconstruction loss of ``model`` on ``dataloader``.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    Each batch loss is weighted by its number of samples, so a smaller final
    batch does not skew the average.

    Args:
        model: Module mapping a flattened batch to an output of the same shape.
        dataloader: Iterable yielding batches of tensors; each batch is moved
            to ``device`` and flattened to ``(batch, -1)``.
        device: Device the batches are moved to. Defaults to ``'cpu'`` like
            the other helpers in this module.

    Returns:
        The per-sample average of the MSE loss as a float.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    criterion = nn.MSELoss()
    weighted_loss = 0.0
    num_samples = 0

    with torch.no_grad():
        for data in dataloader:
            img = data.to(device)
            img = img.view(img.size(0), -1)
            batch_size = img.size(0)
            loss = criterion(model(img), img)
            # MSELoss averages within the batch; scale back up by the batch
            # size so the final division yields a per-sample mean.
            weighted_loss += loss.item() * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute an average loss")

    return weighted_loss / num_samples

0 commit comments

Comments
 (0)