22import torch
33import torch .optim as optim
44import torch .nn as nn
5- from torchvision import transforms
5+ import torch . nn . functional as F
66from torchvision .utils import save_image , make_grid
77import matplotlib .pyplot as plt
8- from PIL import Image
98
109
11- def train_autoencoder (model , dataloader , num_epochs = 5 , learning_rate = 0.001 , device = 'cpu' ):
10+ def train_autoencoder (model , dataloader , num_epochs = 5 , learning_rate = 0.001 , device = 'cpu' , start_epoch = 0 , optimizer = None , ae_type = 'ae' ):
1211 criterion = nn .MSELoss ()
13- optimizer = optim .Adam (model .parameters (), lr = learning_rate )
12+ if optimizer is None :
13+ optimizer = optim .Adam (model .parameters (), lr = learning_rate )
1414
15- for epoch in range (num_epochs ):
15+ for epoch in range (start_epoch , num_epochs ):
1616 for data in dataloader :
1717 img = data .to (device )
18- img = img .view (img .size (0 ), - 1 )
19- output = model (img )
20- loss = criterion (output , img )
18+
19+ if ae_type not in ['conv' , 'conv_vae' ]:
20+ img = img .view (img .size (0 ), - 1 )
21+
22+ if ae_type in ['vae' , 'conv_vae' ]:
23+ recon_x , mu , log_var = model (img )
24+ loss = loss_function_vae (recon_x , img , mu , log_var )
25+ else :
26+ output = model (img )
27+ loss = criterion (output , img )
2128
2229 optimizer .zero_grad ()
2330 loss .backward ()
2431 optimizer .step ()
2532
2633 print (f'Epoch [{ epoch + 1 } /{ num_epochs } ], Loss: { loss .item ():.4f} ' )
34+ save_checkpoint (model , optimizer , epoch , './autoencoder_checkpoint.pth' )
2735
2836 return model
2937
3038
31- def visualize_reconstructions (model , dataloader , num_samples = 10 , device = 'cpu' , save_path = "./samples" ):
39+ def loss_function_vae (recon_x , x , mu , log_var ):
40+ BCE = F .binary_cross_entropy (recon_x , x , reduction = 'sum' )
41+ KLD = - 0.5 * torch .sum (1 + log_var - mu .pow (2 ) - log_var .exp ())
42+ return BCE + KLD
43+
44+
45+ def evaluate_autoencoder (model , dataloader , device , ae_type ):
46+ model .eval ()
47+ total_loss = 0
48+ criterion = nn .MSELoss ()
49+ with torch .no_grad ():
50+ for data in dataloader :
51+ img = data .to (device )
52+
53+ if ae_type not in ['conv' , 'conv_vae' ]:
54+ img = img .view (img .size (0 ), - 1 )
55+
56+ if ae_type in ['vae' , 'conv_vae' ]:
57+ output , _ , _ = model (img )
58+ else :
59+ output = model (img )
60+ loss = criterion (output , img )
61+ total_loss += loss .item ()
62+
63+ return total_loss / len (dataloader )
64+
65+
66+ def visualize_reconstructions (model , dataloader , num_samples = 10 , device = 'cpu' , save_path = "./samples" , ae_type = 'ae' ):
3267 model .eval ()
3368 samples = next (iter (dataloader ))
3469 samples = samples [:num_samples ].to (device )
35- samples = samples .view (samples .size (0 ), - 1 )
36- reconstructions = model (samples )
70+
71+ if ae_type not in ['conv' , 'conv_vae' ]:
72+ samples = samples .view (samples .size (0 ), - 1 )
73+
74+ if ae_type in ['vae' , 'conv_vae' ]:
75+ reconstructions , _ , _ = model (samples )
76+ else :
77+ reconstructions = model (samples )
3778
3879 samples = samples .view (- 1 , 3 , 64 , 64 )
3980 reconstructions = reconstructions .view (- 1 , 3 , 64 , 64 )
4081
41- # Combine as amostras e reconstruções em uma única grade
4282 combined = torch .cat ([samples , reconstructions ], dim = 0 )
4383 grid_img = make_grid (combined , nrow = num_samples )
4484
45- # Visualização usando Matplotlib
4685 plt .imshow (grid_img .permute (1 , 2 , 0 ).cpu ().detach ().numpy ())
4786 plt .axis ('off' )
4887 plt .show ()
@@ -62,15 +101,18 @@ def load_model(model, path, device):
62101 return model
63102
64103
65- def evaluate_autoencoder (model , dataloader , device ):
66- model .eval ()
67- total_loss = 0
68- criterion = nn .MSELoss ()
69- with torch .no_grad ():
70- for data in dataloader :
71- img = data .to (device )
72- img = img .view (img .size (0 ), - 1 )
73- output = model (img )
74- loss = criterion (output , img )
75- total_loss += loss .item ()
76- return total_loss / len (dataloader )
104+ def save_checkpoint (model , optimizer , epoch , path ):
105+ checkpoint = {
106+ 'epoch' : epoch ,
107+ 'model_state_dict' : model .state_dict (),
108+ 'optimizer_state_dict' : optimizer .state_dict (),
109+ }
110+ torch .save (checkpoint , path )
111+
112+
113+ def load_checkpoint (model , optimizer , path , device ):
114+ checkpoint = torch .load (path , map_location = device )
115+ model .load_state_dict (checkpoint ['model_state_dict' ])
116+ optimizer .load_state_dict (checkpoint ['optimizer_state_dict' ])
117+ epoch = checkpoint ['epoch' ]
118+ return model , optimizer , epoch + 1
0 commit comments