
Commit 6633892

Merge pull request #9 from renan-siqueira/develop
Merge develop into main
2 parents: d0031c0 + 511be71

File tree

3 files changed: +235 −44 lines

models/autoencoder.py

Lines changed: 127 additions & 0 deletions
@@ -1,6 +1,9 @@
+import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 
+# Linear autoencoder
 class Autoencoder(nn.Module):
     def __init__(self, input_dim, encoding_dim):
         super(Autoencoder, self).__init__()
@@ -24,3 +27,127 @@ def forward(self, x):
         x = self.encoder(x)
         x = self.decoder(x)
         return x
+
+
+# Convolutional autoencoder
+class ConvolutionalAutoencoder(nn.Module):
+    def __init__(self):
+        super(ConvolutionalAutoencoder, self).__init__()
+
+        # Encoder
+        self.enc1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
+        self.enc2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
+        self.enc3 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
+        self.pool = nn.MaxPool2d(2, 2, return_indices=True)
+
+        # Decoder
+        self.dec1 = nn.ConvTranspose2d(16, 32, kernel_size=2, stride=2)
+        self.dec2 = nn.ConvTranspose2d(32, 64, kernel_size=2, stride=2)
+        self.dec3 = nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2)
+
+    def forward(self, x):
+        x, idxs1 = self.pool(F.relu(self.enc1(x)))
+        x, idxs2 = self.pool(F.relu(self.enc2(x)))
+        x, idxs3 = self.pool(F.relu(self.enc3(x)))
+
+        x = F.relu(self.dec1(x))
+        x = F.relu(self.dec2(x))
+        x = torch.sigmoid(self.dec3(x))
+        return x
+
+
+# Variational autoencoder
+class VariationalAutoencoder(nn.Module):
+    def __init__(self, encoding_dim=128):
+        super(VariationalAutoencoder, self).__init__()
+
+        # Encoder
+        self.enc1 = nn.Linear(3 * 64 * 64, 512)
+        self.enc2 = nn.Linear(512, 256)
+        self.enc3 = nn.Linear(256, encoding_dim)
+
+        # Latent space
+        self.fc_mu = nn.Linear(encoding_dim, encoding_dim)
+        self.fc_log_var = nn.Linear(encoding_dim, encoding_dim)
+
+        # Decoder
+        self.dec1 = nn.Linear(encoding_dim, encoding_dim)
+        self.dec2 = nn.Linear(encoding_dim, 256)
+        self.dec3 = nn.Linear(256, 512)
+        self.dec4 = nn.Linear(512, 3 * 64 * 64)
+
+    def reparameterize(self, mu, log_var):
+        std = torch.exp(0.5 * log_var)
+        eps = torch.randn_like(std)
+        return mu + eps * std
+
+    def forward(self, x):
+        x = F.relu(self.enc1(x))
+        x = F.relu(self.enc2(x))
+        x = F.relu(self.enc3(x))
+
+        mu = self.fc_mu(x)
+        log_var = self.fc_log_var(x)
+        z = self.reparameterize(mu, log_var)
+
+        x = F.relu(self.dec1(z))
+        x = F.relu(self.dec2(x))
+        x = F.relu(self.dec3(x))
+        x = torch.sigmoid(self.dec4(x))
+
+        return x, mu, log_var
+
+
+# Convolutional variational autoencoder
+class ConvolutionalVAE(nn.Module):
+    def __init__(self):
+        super(ConvolutionalVAE, self).__init__()
+
+        # Encoder
+        self.enc1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
+        self.enc2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
+        self.enc3 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
+        self.pool = nn.MaxPool2d(2, 2)
+
+        self.fc_mu = nn.Linear(16 * 8 * 8, 128)
+        self.fc_log_var = nn.Linear(16 * 8 * 8, 128)
+
+        # Decoder
+        self.decoder_input = nn.Linear(128, 16 * 8 * 8)
+        self.dec1 = nn.ConvTranspose2d(16, 32, kernel_size=3, padding=1)
+        self.dec2 = nn.ConvTranspose2d(32, 64, kernel_size=3, padding=1)
+        self.dec3 = nn.ConvTranspose2d(64, 3, kernel_size=3, padding=1)
+
+        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+
+    def reparameterize(self, mu, log_var):
+        std = torch.exp(0.5 * log_var)
+        eps = torch.randn_like(std)
+        return mu + eps * std
+
+    def forward(self, x):
+        # Encoding
+        x = F.relu(self.enc1(x))
+        x = self.pool(x)
+        x = F.relu(self.enc2(x))
+        x = self.pool(x)
+        x = F.relu(self.enc3(x))
+        x = self.pool(x)
+
+        x = x.view(x.size(0), -1)  # Flatten
+
+        mu = self.fc_mu(x)
+        log_var = self.fc_log_var(x)
+        z = self.reparameterize(mu, log_var)
+
+        # Decoding
+        x = self.decoder_input(z)
+        x = x.view(x.size(0), 16, 8, 8)  # Unflatten
+        x = self.upsample(x)
+        x = F.relu(self.dec1(x))
+        x = self.upsample(x)
+        x = F.relu(self.dec2(x))
+        x = self.upsample(x)
+        x = torch.sigmoid(self.dec3(x))
+
+        return x, mu, log_var
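A quick shape check for the new classes (my own sketch, not part of this commit): both convolutional models assume 64×64 RGB inputs, which the 16 * 8 * 8 bottleneck in ConvolutionalVAE implies (64 → 32 → 16 → 8 after three 2×2 poolings).

import torch
from models.autoencoder import ConvolutionalAutoencoder, ConvolutionalVAE

batch = torch.rand(4, 3, 64, 64)  # four random images in [0, 1]

# Plain convolutional AE: three 2x2 poolings down, three stride-2 transposed convs back up.
recon = ConvolutionalAutoencoder()(batch)
assert recon.shape == batch.shape

# Convolutional VAE: returns the reconstruction plus the latent statistics.
recon, mu, log_var = ConvolutionalVAE()(batch)
assert recon.shape == batch.shape
assert mu.shape == log_var.shape == (4, 128)  # one 128-d (mu, log_var) pair per image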

run.py

Lines changed: 41 additions & 19 deletions
@@ -1,39 +1,61 @@
 import os
 import torch
 
-from models.autoencoder import Autoencoder
+from models.autoencoder import Autoencoder, ConvolutionalAutoencoder, VariationalAutoencoder, ConvolutionalVAE
 from utils.dataloader import get_dataloader
-from utils.trainer import train_autoencoder, visualize_reconstructions, save_model, load_model, evaluate_autoencoder
+from utils.trainer import train_autoencoder, visualize_reconstructions, load_checkpoint, evaluate_autoencoder
 from settings import settings
 
 
-def main(load_trained_model):
+def main(load_trained_model, ae_type='ae'):
     BATCH_SIZE = 32
     INPUT_DIM = 3 * 64 * 64
-    ENCODING_DIM = 12
-    NUM_EPOCHS = 1000
+    ENCODING_DIM = 64
+    NUM_EPOCHS = 200
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
     dataloader = get_dataloader(settings.DATA_PATH, BATCH_SIZE)
-    model = Autoencoder(INPUT_DIM, ENCODING_DIM).to(device)
 
-    if load_trained_model:
-        trained_model = load_model(model, settings.PATH_SAVED_MODEL, device=device)
+    if ae_type == 'ae':
+        model = Autoencoder(INPUT_DIM, ENCODING_DIM).to(device)
+    elif ae_type == 'conv':
+        model = ConvolutionalAutoencoder().to(device)
+    elif ae_type == 'vae':
+        model = VariationalAutoencoder().to(device)
+    elif ae_type == 'conv_vae':
+        model = ConvolutionalVAE().to(device)
     else:
-        trained_model = train_autoencoder(model, dataloader, NUM_EPOCHS, device=device)
+        raise ValueError(f"Unknown AE type: {ae_type}")
+
+    optimizer = torch.optim.Adam(model.parameters())
+
+    start_epoch = 0
+    if os.path.exists(settings.PATH_SAVED_MODEL):
+        model, optimizer, start_epoch = load_checkpoint(
+            model, optimizer, settings.PATH_SAVED_MODEL, device
+        )
+        print(f"Loaded checkpoint and continuing training from epoch {start_epoch}.")
+
+    if not load_trained_model:
+        train_autoencoder(
+            model,
+            dataloader,
+            num_epochs=NUM_EPOCHS,
+            device=device,
+            start_epoch=start_epoch,
+            optimizer=optimizer,
+            ae_type=ae_type
+        )
+        print(f"Training complete up to epoch {NUM_EPOCHS}!")
 
     valid_dataloader = get_dataloader(settings.VALID_DATA_PATH, BATCH_SIZE)
-
-    save_path = os.path.join('./', settings.PATH_SAVED_MODEL)
-    save_model(trained_model, save_path)
-    print(f"Model saved to {save_path}")
-
-    avg_valid_loss = evaluate_autoencoder(trained_model, valid_dataloader, device)
+    avg_valid_loss = evaluate_autoencoder(model, valid_dataloader, device, ae_type)
     print(f"Average validation loss: {avg_valid_loss:.4f}")
-
-    visualize_reconstructions(trained_model, valid_dataloader, num_samples=10, device=device)
+
+    visualize_reconstructions(
+        model, valid_dataloader, num_samples=10, device=device, ae_type=ae_type
+    )
 
 
 if __name__ == "__main__":
-    main(False)
+    main(False, ae_type='conv_vae')
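A usage sketch (mine, not part of the commit): main now dispatches on ae_type and, whenever a checkpoint already exists at settings.PATH_SAVED_MODEL, resumes from the stored epoch before training or evaluating.

# The four accepted ae_type values, matching the dispatch in main():
main(False, ae_type='ae')        # train + evaluate the linear Autoencoder
main(False, ae_type='conv')      # ... the ConvolutionalAutoencoder
main(False, ae_type='vae')       # ... the VariationalAutoencoder
main(True, ae_type='conv_vae')   # skip training; evaluate/visualize a saved ConvolutionalVAE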

utils/trainer.py

Lines changed: 67 additions & 25 deletions
@@ -2,47 +2,86 @@
 import torch
 import torch.optim as optim
 import torch.nn as nn
-from torchvision import transforms
+import torch.nn.functional as F
 from torchvision.utils import save_image, make_grid
 import matplotlib.pyplot as plt
-from PIL import Image
 
 
-def train_autoencoder(model, dataloader, num_epochs=5, learning_rate=0.001, device='cpu'):
+def train_autoencoder(model, dataloader, num_epochs=5, learning_rate=0.001, device='cpu', start_epoch=0, optimizer=None, ae_type='ae'):
     criterion = nn.MSELoss()
-    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+    if optimizer is None:
+        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
 
-    for epoch in range(num_epochs):
+    for epoch in range(start_epoch, num_epochs):
         for data in dataloader:
             img = data.to(device)
-            img = img.view(img.size(0), -1)
-            output = model(img)
-            loss = criterion(output, img)
+
+            if ae_type not in ['conv', 'conv_vae']:
+                img = img.view(img.size(0), -1)
+
+            if ae_type in ['vae', 'conv_vae']:
+                recon_x, mu, log_var = model(img)
+                loss = loss_function_vae(recon_x, img, mu, log_var)
+            else:
+                output = model(img)
+                loss = criterion(output, img)
 
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
 
         print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
+        save_checkpoint(model, optimizer, epoch, './autoencoder_checkpoint.pth')
 
     return model
 
 
-def visualize_reconstructions(model, dataloader, num_samples=10, device='cpu', save_path="./samples"):
+def loss_function_vae(recon_x, x, mu, log_var):
+    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
+    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
+    return BCE + KLD
+
+
+def evaluate_autoencoder(model, dataloader, device, ae_type):
+    model.eval()
+    total_loss = 0
+    criterion = nn.MSELoss()
+    with torch.no_grad():
+        for data in dataloader:
+            img = data.to(device)
+
+            if ae_type not in ['conv', 'conv_vae']:
+                img = img.view(img.size(0), -1)
+
+            if ae_type in ['vae', 'conv_vae']:
+                output, _, _ = model(img)
+            else:
+                output = model(img)
+            loss = criterion(output, img)
+            total_loss += loss.item()
+
+    return total_loss / len(dataloader)
+
+
+def visualize_reconstructions(model, dataloader, num_samples=10, device='cpu', save_path="./samples", ae_type='ae'):
     model.eval()
     samples = next(iter(dataloader))
     samples = samples[:num_samples].to(device)
-    samples = samples.view(samples.size(0), -1)
-    reconstructions = model(samples)
+
+    if ae_type not in ['conv', 'conv_vae']:
+        samples = samples.view(samples.size(0), -1)
+
+    if ae_type in ['vae', 'conv_vae']:
+        reconstructions, _, _ = model(samples)
+    else:
+        reconstructions = model(samples)
 
     samples = samples.view(-1, 3, 64, 64)
     reconstructions = reconstructions.view(-1, 3, 64, 64)
 
-    # Combine the samples and reconstructions into a single grid
     combined = torch.cat([samples, reconstructions], dim=0)
     grid_img = make_grid(combined, nrow=num_samples)
 
-    # Visualization with Matplotlib
     plt.imshow(grid_img.permute(1, 2, 0).cpu().detach().numpy())
     plt.axis('off')
     plt.show()
@@ -62,15 +101,18 @@ def load_model(model, path, device):
     return model
 
 
-def evaluate_autoencoder(model, dataloader, device):
-    model.eval()
-    total_loss = 0
-    criterion = nn.MSELoss()
-    with torch.no_grad():
-        for data in dataloader:
-            img = data.to(device)
-            img = img.view(img.size(0), -1)
-            output = model(img)
-            loss = criterion(output, img)
-            total_loss += loss.item()
-    return total_loss / len(dataloader)
+def save_checkpoint(model, optimizer, epoch, path):
+    checkpoint = {
+        'epoch': epoch,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+    }
+    torch.save(checkpoint, path)
+
+
+def load_checkpoint(model, optimizer, path, device):
+    checkpoint = torch.load(path, map_location=device)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+    epoch = checkpoint['epoch']
+    return model, optimizer, epoch + 1
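For reference, the new loss_function_vae is the standard VAE objective (the negative evidence lower bound): a sum-reduced binary cross-entropy reconstruction term plus the closed-form KL divergence between the approximate posterior N(mu, sigma^2) and a standard normal prior,

\mathcal{L} = \mathrm{BCE}(\hat{x}, x) - \frac{1}{2} \sum_j \left( 1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2 \right),

which is exactly the KLD line in the code, with log_var = \log\sigma^2. Note that the BCE term requires both recon_x and x to lie in [0, 1]; the models' final sigmoid guarantees this for recon_x, and the dataloader is assumed (it is not shown in this diff) to normalize inputs to the same range.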
