diff --git a/vae/main.py b/vae/main.py
index 3b330b0..5a71eb4 100644
--- a/vae/main.py
+++ b/vae/main.py
@@ -14,6 +14,7 @@
 
 parser = argparse.ArgumentParser()
+# training
 parser.add_argument('--img_channels', type=int, default=3, help='Numer of channels for images')
 parser.add_argument('--model_dim', type=float, default=128, help='model dimensions multiplier')
 parser.add_argument('--z_dim', type=float, default=100, help='dimension of random noise latent vector')
@@ -32,6 +33,10 @@
 
 parser.add_argument('--device_ids', type=list, default=[0, 1], help='List of GPU devices')
 parser.add_argument('--img_ext', type=str, default='.jpg', help='Image extentions')
 parser.add_argument('--checkpoint_dir', type=str, default='vae/model_weights', help='Path to where model weights will be saved')
+
+# for sampler
+parser.add_argument('--sample', action="store_true", default=False, help='Sample from VAE')
+parser.add_argument('--walk', action="store_true", default=False, help='Walk through a feature & sample')
 args = parser.parse_args()
diff --git a/vae/model.py b/vae/model.py
index bc6fc56..41e55ab 100644
--- a/vae/model.py
+++ b/vae/model.py
@@ -46,7 +46,7 @@ def forward(self, x):
 
     def sample(self, z=None, num_samples=50):
         if z is None:
-            z = torch.randn(num_samples, self.args.z_dim, device=self.device)
+            z = torch.randn(num_samples, self.args.z_dim, device=next(self.parameters()).device)
         z = self.projector(z).view(-1, self.args.model_dim * 8, self.args.img_size // (2**4), self.args.img_size // (2**4))
         return self.decoder(z)
 
diff --git a/vae/sample.py b/vae/sample.py
new file mode 100644
index 0000000..2df126c
--- /dev/null
+++ b/vae/sample.py
@@ -0,0 +1,51 @@
+import torch
+import numpy as np
+from pathlib import Path
+from datetime import datetime
+import torchvision
+
+from vae.model import VAE
+from vae.main import parser, args
+
+
+class Sampler:
+    def __init__(self, sample_path='vae/samples', ext='.jpg'):
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.vae = torch.nn.DataParallel(VAE(args), device_ids=args.device_ids).to(self.device)
+        self.vae.load_state_dict(torch.load(f"{args.checkpoint_dir}/VAE.pth")['model'])
+        self.vae.eval()
+        Path(sample_path).mkdir(parents=True, exist_ok=True)
+
+        self.sample_path = sample_path
+        self.ext = ext
+
+    def sample(self):
+        with torch.no_grad():
+            samples = self.vae.module.sample(num_samples=args.sample_size)
+        torchvision.utils.save_image(
+            samples,
+            self.sample_path + f'/sample_{int(datetime.now().timestamp()*1e6)}' + self.ext)
+
+    def generate_walk_z(self):
+        z = torch.randn(args.z_dim, device=self.device)
+        z = z.repeat(args.sample_size).view(args.sample_size, args.z_dim)
+        walk_dim = np.random.choice(list(range(args.z_dim)))
+        # build the walk values on the same device as z to avoid a CPU/CUDA mismatch
+        z[:, walk_dim] = torch.linspace(-2, 2, args.sample_size, device=self.device)
+        return z
+
+    def walk(self):
+        z = self.generate_walk_z()
+        with torch.no_grad():
+            samples = self.vae.module.sample(z=z)
+        torchvision.utils.save_image(
+            samples,
+            self.sample_path + f'/walk_{int(datetime.now().timestamp()*1e6)}' + self.ext)
+
+
+if __name__ == '__main__':
+    sampler = Sampler()
+    if args.sample:
+        sampler.sample()
+    if args.walk:
+        sampler.walk()