|
| 1 | +import argparse |
| 2 | +import torch |
| 3 | +import numpy as np |
| 4 | +from pathlib import Path |
| 5 | +from datetime import datetime |
| 6 | +import torchvision |
| 7 | + |
| 8 | +from vae.model import VAE |
| 9 | +from vae.main import parser, args |
| 10 | + |
| 11 | + |
| 12 | +class Sampler: |
| 13 | + def __init__(self, sample_path='vae/samples', ext='.jpg'): |
| 14 | + self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') |
| 15 | + self.vae = torch.nn.DataParallel(VAE(args), device_ids=args.device_ids).to(self.device) |
| 16 | + self.vae.load_state_dict(torch.load(f"{args.checkpoint_dir}/VAE.pth")['model']) |
| 17 | + self.vae.eval() |
| 18 | + Path(sample_path).mkdir(parents=True, exist_ok=True) |
| 19 | + |
| 20 | + self.sample_path = sample_path |
| 21 | + self.ext = ext |
| 22 | + |
| 23 | + def sample(self): |
| 24 | + with torch.no_grad(): |
| 25 | + samples = self.vae.module.sample(num_samples=args.sample_size) |
| 26 | + torchvision.utils.save_image( |
| 27 | + samples, |
| 28 | + self.sample_path + f'/sample_{int(datetime.now().timestamp()*1e6)}' + self.ext) |
| 29 | + |
| 30 | + def generate_walk_z(self): |
| 31 | + z = torch.randn(args.z_dim, device=self.device) |
| 32 | + z = z.repeat(args.sample_size).view(args.sample_size, args.z_dim) |
| 33 | + walk_dim = np.random.choice(list(range(args.z_dim))) |
| 34 | + z[:, walk_dim] = torch.linspace(-2, 2, args.sample_size) |
| 35 | + return z |
| 36 | + |
| 37 | + def walk(self): |
| 38 | + z = self.generate_walk_z() |
| 39 | + with torch.no_grad(): |
| 40 | + samples = self.vae.module.sample(z=z) |
| 41 | + torchvision.utils.save_image( |
| 42 | + samples, |
| 43 | + self.sample_path + f'/walk_{int(datetime.now().timestamp()*1e6)}' + self.ext) |
| 44 | + |
| 45 | + |
| 46 | +if __name__ == '__main__': |
| 47 | + sampler = Sampler() |
| 48 | + if args.sample: |
| 49 | + sampler.sample() |
| 50 | + if args.walk: |
| 51 | + sampler.walk() |
0 commit comments