Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions models/autoencoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


# Autoencoder Linear
class Autoencoder(nn.Module):
def __init__(self, input_dim, encoding_dim):
super(Autoencoder, self).__init__()
Expand All @@ -24,3 +27,127 @@ def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x


# Convolutional Autoencoder
class ConvolutionalAutoencoder(nn.Module):
    """Convolutional autoencoder for 3-channel images.

    The encoder halves the spatial resolution three times with 2x2 max
    pooling, and the decoder restores it with three stride-2 transposed
    convolutions, so input H and W must be divisible by 8. The final
    sigmoid squashes the reconstruction into [0, 1].
    """

    def __init__(self):
        super(ConvolutionalAutoencoder, self).__init__()

        # Encoder: 3 -> 64 -> 32 -> 16 channels, each conv followed by pooling.
        self.enc1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.enc2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.enc3 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        # Fix: the original used return_indices=True, but the indices were
        # never consumed (the decoder has no MaxUnpool2d) — drop the dead output.
        self.pool = nn.MaxPool2d(2, 2)

        # Decoder: stride-2 transposed convs double H/W at each step.
        self.dec1 = nn.ConvTranspose2d(16, 32, kernel_size=2, stride=2)
        self.dec2 = nn.ConvTranspose2d(32, 64, kernel_size=2, stride=2)
        self.dec3 = nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2)

    def forward(self, x):
        """Encode then decode x: (N, 3, H, W) -> (N, 3, H, W) in [0, 1]."""
        x = self.pool(F.relu(self.enc1(x)))
        x = self.pool(F.relu(self.enc2(x)))
        x = self.pool(F.relu(self.enc3(x)))

        x = F.relu(self.dec1(x))
        x = F.relu(self.dec2(x))
        x = torch.sigmoid(self.dec3(x))
        return x


# Variational Autoencoder (fully connected)
class VariationalAutoencoder(nn.Module):
    """Fully connected VAE over flattened 3x64x64 image vectors.

    forward() returns (reconstruction, mu, log_var); the reconstruction is
    sigmoid-activated, so its values lie in [0, 1].
    """

    def __init__(self, encoding_dim=128):
        super(VariationalAutoencoder, self).__init__()

        # Encoder MLP: 3*64*64 -> 512 -> 256 -> encoding_dim.
        self.enc1 = nn.Linear(3 * 64 * 64, 512)
        self.enc2 = nn.Linear(512, 256)
        self.enc3 = nn.Linear(256, encoding_dim)

        # Heads producing the latent Gaussian parameters.
        self.fc_mu = nn.Linear(encoding_dim, encoding_dim)
        self.fc_log_var = nn.Linear(encoding_dim, encoding_dim)

        # Decoder MLP: encoding_dim -> encoding_dim -> 256 -> 512 -> 3*64*64.
        self.dec1 = nn.Linear(encoding_dim, encoding_dim)
        self.dec2 = nn.Linear(encoding_dim, 256)
        self.dec3 = nn.Linear(256, 512)
        self.dec4 = nn.Linear(512, 3 * 64 * 64)

    def reparameterize(self, mu, log_var):
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick."""
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        hidden = F.relu(self.enc1(x))
        hidden = F.relu(self.enc2(hidden))
        hidden = F.relu(self.enc3(hidden))

        mu, log_var = self.fc_mu(hidden), self.fc_log_var(hidden)
        z = self.reparameterize(mu, log_var)

        out = F.relu(self.dec1(z))
        out = F.relu(self.dec2(out))
        out = F.relu(self.dec3(out))
        out = torch.sigmoid(self.dec4(out))

        return out, mu, log_var


# Convolutional Variational Autoencoder
class ConvolutionalVAE(nn.Module):
    """Convolutional VAE for 3x64x64 images with a 128-d latent space.

    Three conv+pool stages reduce 64x64 input to a 16x8x8 feature map,
    which is flattened into the latent heads; decoding mirrors this with
    bilinear upsampling between transposed convolutions.
    forward() returns (reconstruction, mu, log_var).
    """

    def __init__(self):
        super(ConvolutionalVAE, self).__init__()

        # Encoder convolutions: 3 -> 64 -> 32 -> 16 channels.
        self.enc1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.enc2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.enc3 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # Latent heads over the flattened 16x8x8 feature map.
        self.fc_mu = nn.Linear(16 * 8 * 8, 128)
        self.fc_log_var = nn.Linear(16 * 8 * 8, 128)

        # Decoder: project the latent code back to 16x8x8, then 16 -> 32 -> 64 -> 3.
        self.decoder_input = nn.Linear(128, 16 * 8 * 8)
        self.dec1 = nn.ConvTranspose2d(16, 32, kernel_size=3, padding=1)
        self.dec2 = nn.ConvTranspose2d(32, 64, kernel_size=3, padding=1)
        self.dec3 = nn.ConvTranspose2d(64, 3, kernel_size=3, padding=1)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def reparameterize(self, mu, log_var):
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick."""
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        # --- Encode: three conv + pool stages, 64x64 -> 8x8 spatial ---
        feat = self.pool(F.relu(self.enc1(x)))
        feat = self.pool(F.relu(self.enc2(feat)))
        feat = self.pool(F.relu(self.enc3(feat)))

        flat = feat.view(feat.size(0), -1)  # flatten to (N, 16*8*8)

        mu, log_var = self.fc_mu(flat), self.fc_log_var(flat)
        z = self.reparameterize(mu, log_var)

        # --- Decode: unflatten to 16x8x8, then upsample + conv three times ---
        out = self.decoder_input(z).view(z.size(0), 16, 8, 8)
        out = F.relu(self.dec1(self.upsample(out)))
        out = F.relu(self.dec2(self.upsample(out)))
        out = torch.sigmoid(self.dec3(self.upsample(out)))

        return out, mu, log_var
60 changes: 41 additions & 19 deletions run.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,61 @@
import os
import torch

from models.autoencoder import Autoencoder
from models.autoencoder import Autoencoder, ConvolutionalAutoencoder, VariationalAutoencoder, ConvolutionalVAE
from utils.dataloader import get_dataloader
from utils.trainer import train_autoencoder, visualize_reconstructions, save_model, load_model, evaluate_autoencoder
from utils.trainer import train_autoencoder, visualize_reconstructions, load_checkpoint, evaluate_autoencoder
from settings import settings


def main(load_trained_model, ae_type='ae'):
    """Build, optionally train, and evaluate the selected autoencoder.

    Args:
        load_trained_model: if True, skip training and only evaluate the
            checkpointed model.
        ae_type: one of 'ae', 'conv', 'vae', 'conv_vae'.

    Raises:
        ValueError: if ae_type is not one of the supported variants.
    """
    BATCH_SIZE = 32
    INPUT_DIM = 3 * 64 * 64
    ENCODING_DIM = 64
    NUM_EPOCHS = 200

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataloader = get_dataloader(settings.DATA_PATH, BATCH_SIZE)

    if ae_type == 'ae':
        model = Autoencoder(INPUT_DIM, ENCODING_DIM).to(device)
    elif ae_type == 'conv':
        model = ConvolutionalAutoencoder().to(device)
    elif ae_type == 'vae':
        # NOTE(review): uses the VAE's default latent size, not ENCODING_DIM —
        # confirm this is intentional.
        model = VariationalAutoencoder().to(device)
    elif ae_type == 'conv_vae':
        model = ConvolutionalVAE().to(device)
    else:
        raise ValueError(f"Unknown AE type: {ae_type}")

    optimizer = torch.optim.Adam(model.parameters())

    # Resume from a previous run when a checkpoint file exists.
    # NOTE(review): the checkpoint is not tagged with ae_type; loading one
    # saved for a different architecture will fail on state-dict mismatch.
    start_epoch = 0
    if os.path.exists(settings.PATH_SAVED_MODEL):
        model, optimizer, start_epoch = load_checkpoint(
            model, optimizer, settings.PATH_SAVED_MODEL, device
        )
        print(f"Loaded checkpoint and continuing training from epoch {start_epoch}.")

    if not load_trained_model:
        train_autoencoder(
            model,
            dataloader,
            num_epochs=NUM_EPOCHS,
            device=device,
            start_epoch=start_epoch,
            optimizer=optimizer,
            ae_type=ae_type
        )
        print(f"Training complete up to epoch {NUM_EPOCHS}!")

    valid_dataloader = get_dataloader(settings.VALID_DATA_PATH, BATCH_SIZE)

    avg_valid_loss = evaluate_autoencoder(model, valid_dataloader, device, ae_type)
    print(f"Average validation loss: {avg_valid_loss:.4f}")

    visualize_reconstructions(
        model, valid_dataloader, num_samples=10, device=device, ae_type=ae_type
    )


if __name__ == "__main__":
    main(False, ae_type='conv_vae')
92 changes: 67 additions & 25 deletions utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,86 @@
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import transforms
import torch.nn.functional as F
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from PIL import Image


def train_autoencoder(model, dataloader, num_epochs=5, learning_rate=0.001, device='cpu', start_epoch=0, optimizer=None, ae_type='ae'):
    """Train any of the supported autoencoder variants.

    Args:
        model: the autoencoder module to optimize (modified in place).
        dataloader: iterable yielding image batches.
        num_epochs: train up to this epoch index (exclusive upper bound).
        learning_rate: Adam learning rate, used only when optimizer is None.
        device: device the batches are moved to.
        start_epoch: first epoch index, e.g. when resuming from a checkpoint.
        optimizer: optional pre-built (possibly checkpoint-restored) optimizer.
        ae_type: one of 'ae', 'conv', 'vae', 'conv_vae'; selects input
            flattening and the loss (VAE variants use BCE + KL, others MSE).

    Returns:
        The trained model.
    """
    criterion = nn.MSELoss()
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Fix: ensure training mode in case the model was left in eval()
    # by a previous evaluation/visualization pass.
    model.train()

    for epoch in range(start_epoch, num_epochs):
        for data in dataloader:
            img = data.to(device)

            # Fully connected models consume flattened vectors;
            # convolutional ones keep the (N, C, H, W) layout.
            if ae_type not in ['conv', 'conv_vae']:
                img = img.view(img.size(0), -1)

            if ae_type in ['vae', 'conv_vae']:
                recon_x, mu, log_var = model(img)
                loss = loss_function_vae(recon_x, img, mu, log_var)
            else:
                output = model(img)
                loss = criterion(output, img)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # NOTE(review): this reports the LAST batch's loss, not an epoch average.
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
        save_checkpoint(model, optimizer, epoch, './autoencoder_checkpoint.pth')

    return model


def visualize_reconstructions(model, dataloader, num_samples=10, device='cpu', save_path="./samples"):
def loss_function_vae(recon_x, x, mu, log_var):
    """Summed VAE objective: reconstruction BCE plus KL divergence to N(0, I)."""
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + kl_div


def evaluate_autoencoder(model, dataloader, device, ae_type):
    """Return the mean per-batch MSE reconstruction loss over dataloader."""
    model.eval()
    criterion = nn.MSELoss()
    running = 0.0
    with torch.no_grad():
        for batch in dataloader:
            img = batch.to(device)

            # Fully connected models consume flattened vectors.
            if ae_type not in ['conv', 'conv_vae']:
                img = img.view(img.size(0), -1)

            # VAE variants return (reconstruction, mu, log_var); only the
            # reconstruction is scored here.
            if ae_type in ['vae', 'conv_vae']:
                recon, _, _ = model(img)
            else:
                recon = model(img)
            running += criterion(recon, img).item()

    return running / len(dataloader)


def visualize_reconstructions(model, dataloader, num_samples=10, device='cpu', save_path="./samples", ae_type='ae'):
model.eval()
samples = next(iter(dataloader))
samples = samples[:num_samples].to(device)
samples = samples.view(samples.size(0), -1)
reconstructions = model(samples)

if ae_type not in ['conv', 'conv_vae']:
samples = samples.view(samples.size(0), -1)

if ae_type in ['vae', 'conv_vae']:
reconstructions, _, _ = model(samples)
else:
reconstructions = model(samples)

samples = samples.view(-1, 3, 64, 64)
reconstructions = reconstructions.view(-1, 3, 64, 64)

# Combine as amostras e reconstruções em uma única grade
combined = torch.cat([samples, reconstructions], dim=0)
grid_img = make_grid(combined, nrow=num_samples)

# Visualização usando Matplotlib
plt.imshow(grid_img.permute(1, 2, 0).cpu().detach().numpy())
plt.axis('off')
plt.show()
Expand All @@ -62,15 +101,18 @@ def load_model(model, path, device):
return model


def evaluate_autoencoder(model, dataloader, device):
model.eval()
total_loss = 0
criterion = nn.MSELoss()
with torch.no_grad():
for data in dataloader:
img = data.to(device)
img = img.view(img.size(0), -1)
output = model(img)
loss = criterion(output, img)
total_loss += loss.item()
return total_loss / len(dataloader)
def save_checkpoint(model, optimizer, epoch, path):
    """Serialize model and optimizer state plus the epoch index to path."""
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(model, optimizer, path, device):
    """Restore model/optimizer state from path.

    Returns (model, optimizer, next_epoch), where next_epoch is the epoch
    AFTER the last one recorded in the checkpoint, ready for resuming.
    """
    state = torch.load(path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    return model, optimizer, state['epoch'] + 1