1
+ import torch
2
+ import config
3
+ from tqdm import tqdm
4
+ import torch .optim as optim
5
+ from torch .utils .data import DataLoader
6
+ from torchvision .utils import save_image
7
+ from dataset import HorseZebraDataset
8
+ from generator_model import Generator
9
+ from utils import load_checkpoint
10
+
11
+
12
+
13
+ def test_fn (gen_Z , gen_H , loader ):
14
+
15
+ loop = tqdm (loader , leave = True )
16
+
17
+ for idx , (zebra , horse ) in enumerate (loop ):
18
+ zebra = zebra .to (config .DEVICE )
19
+ horse = horse .to (config .DEVICE )
20
+
21
+ with torch .cuda .amp .autocast ():
22
+ fake_horse = gen_H (zebra )
23
+ fake_zebra = gen_Z (horse )
24
+
25
+ save_image (fake_horse * 0.5 + 0.5 , f"saved_images/horse_{ idx } .png" )
26
+ save_image (fake_zebra * 0.5 + 0.5 , f"saved_images/zebra_{ idx } .png" )
27
+
28
+ def main ():
29
+ gen_Z = Generator (img_channels = 3 , num_residuals = 9 ).to (config .DEVICE )
30
+ gen_H = Generator (img_channels = 3 , num_residuals = 9 ).to (config .DEVICE )
31
+
32
+
33
+ opt_gen = optim .Adam (
34
+ list (gen_Z .parameters ()) + list (gen_H .parameters ()),
35
+ lr = config .LEARNING_RATE ,
36
+ betas = (0.5 , 0.999 ),
37
+ )
38
+ load_checkpoint (
39
+ config .CHECKPOINT_GEN_H ,
40
+ gen_H ,
41
+ opt_gen ,
42
+ config .LEARNING_RATE ,
43
+ )
44
+ load_checkpoint (
45
+ config .CHECKPOINT_GEN_Z ,
46
+ gen_Z ,
47
+ opt_gen ,
48
+ config .LEARNING_RATE ,
49
+ )
50
+
51
+ val_dataset = HorseZebraDataset (
52
+ root_horse = config .VAL_DIR + "/testA" ,
53
+ root_zebra = config .VAL_DIR + "/testB" ,
54
+ transform = config .transforms ,
55
+ )
56
+
57
+ loader = DataLoader (
58
+ val_dataset ,
59
+ batch_size = config .BATCH_SIZE ,
60
+ shuffle = False ,
61
+ num_workers = config .NUM_WORKERS ,
62
+ pin_memory = True ,
63
+ )
64
+ test_fn (gen_Z , gen_H , loader )
65
+
66
+ if __name__ == "__main__" :
67
+ main ()
0 commit comments