Skip to content

Commit a50ae95

Browse files
authored
Merge pull request #1 from loliconce/loliconce-patch-1
Add test function
2 parents 558557c + 0d64ccb commit a50ae95

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

ML/Pytorch/GANs/CycleGAN/test.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import torch
2+
import config
3+
from tqdm import tqdm
4+
import torch.optim as optim
5+
from torch.utils.data import DataLoader
6+
from torchvision.utils import save_image
7+
from dataset import HorseZebraDataset
8+
from generator_model import Generator
9+
from utils import load_checkpoint
10+
11+
12+
13+
def test_fn(gen_Z, gen_H, loader):
14+
15+
loop = tqdm(loader, leave=True)
16+
17+
for idx, (zebra, horse) in enumerate(loop):
18+
zebra = zebra.to(config.DEVICE)
19+
horse = horse.to(config.DEVICE)
20+
21+
with torch.cuda.amp.autocast():
22+
fake_horse = gen_H(zebra)
23+
fake_zebra = gen_Z(horse)
24+
25+
save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
26+
save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")
27+
28+
def main():
29+
gen_Z = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
30+
gen_H = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
31+
32+
33+
opt_gen = optim.Adam(
34+
list(gen_Z.parameters()) + list(gen_H.parameters()),
35+
lr=config.LEARNING_RATE,
36+
betas=(0.5, 0.999),
37+
)
38+
load_checkpoint(
39+
config.CHECKPOINT_GEN_H,
40+
gen_H,
41+
opt_gen,
42+
config.LEARNING_RATE,
43+
)
44+
load_checkpoint(
45+
config.CHECKPOINT_GEN_Z,
46+
gen_Z,
47+
opt_gen,
48+
config.LEARNING_RATE,
49+
)
50+
51+
val_dataset = HorseZebraDataset(
52+
root_horse=config.VAL_DIR + "/testA",
53+
root_zebra=config.VAL_DIR + "/testB",
54+
transform=config.transforms,
55+
)
56+
57+
loader = DataLoader(
58+
val_dataset,
59+
batch_size=config.BATCH_SIZE,
60+
shuffle=False,
61+
num_workers=config.NUM_WORKERS,
62+
pin_memory=True,
63+
)
64+
test_fn(gen_Z, gen_H, loader)
65+
66+
if __name__ == "__main__":
67+
main()

0 commit comments

Comments
 (0)