Skip to content

Commit d4cd095

Browse files
feat: create trainer file
1 parent 227e7ca commit d4cd095

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

utils/__init__.py

Whitespace-only changes.

utils/trainer.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import os
2+
import torch
3+
import torch.optim as optim
4+
import torch.nn as nn
5+
from torchvision import transforms
6+
from torchvision.utils import save_image, make_grid
7+
import matplotlib.pyplot as plt
8+
from PIL import Image
9+
10+
11+
def train_autoencoder(model, dataloader, num_epochs=5, learning_rate=0.001, device='cpu'):
12+
criterion = nn.MSELoss()
13+
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
14+
15+
for epoch in range(num_epochs):
16+
for data in dataloader:
17+
img = data.to(device)
18+
img = img.view(img.size(0), -1)
19+
output = model(img)
20+
loss = criterion(output, img)
21+
22+
optimizer.zero_grad()
23+
loss.backward()
24+
optimizer.step()
25+
26+
print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
27+
28+
return model
29+
30+
31+
def visualize_reconstructions(model, dataloader, num_samples=10, device='cpu', save_path="./samples"):
32+
model.eval()
33+
samples = next(iter(dataloader))
34+
samples = samples[:num_samples].to(device)
35+
samples = samples.view(samples.size(0), -1)
36+
reconstructions = model(samples)
37+
38+
samples = samples.view(-1, 3, 64, 64)
39+
reconstructions = reconstructions.view(-1, 3, 64, 64)
40+
41+
# Combine as amostras e reconstruções em uma única grade
42+
combined = torch.cat([samples, reconstructions], dim=0)
43+
grid_img = make_grid(combined, nrow=num_samples)
44+
45+
# Visualização usando Matplotlib
46+
plt.imshow(grid_img.permute(1, 2, 0).cpu().detach().numpy())
47+
plt.axis('off')
48+
plt.show()
49+
50+
if not os.path.exists(save_path):
51+
os.makedirs(save_path)
52+
save_image(grid_img, os.path.join(save_path, 'combined_samples.png'))
53+
54+
55+
def save_model(model, path):
56+
torch.save(model.state_dict(), path)
57+
58+
59+
def load_model(model, path, device):
60+
model.load_state_dict(torch.load(path, map_location=device))
61+
model.eval()
62+
return model
63+
64+
65+
def evaluate_autoencoder(model, dataloader, device):
66+
model.eval()
67+
total_loss = 0
68+
criterion = nn.MSELoss()
69+
with torch.no_grad():
70+
for data in dataloader:
71+
img = data.to(device)
72+
img = img.view(img.size(0), -1)
73+
output = model(img)
74+
loss = criterion(output, img)
75+
total_loss += loss.item()
76+
return total_loss / len(dataloader)

0 commit comments

Comments
 (0)