Skip to content

Commit

Permalink
added sample module
Browse files Browse the repository at this point in the history
  • Loading branch information
zsdonghao committed Dec 4, 2020
1 parent e0baeff commit 1341dbb
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 1 deletion.
5 changes: 5 additions & 0 deletions vae/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


parser = argparse.ArgumentParser()
# training
parser.add_argument('--img_channels', type=int, default=3, help='Numer of channels for images')
parser.add_argument('--model_dim', type=float, default=128, help='model dimensions multiplier')
parser.add_argument('--z_dim', type=float, default=100, help='dimension of random noise latent vector')
Expand All @@ -32,6 +33,10 @@
parser.add_argument('--device_ids', type=list, default=[0, 1], help='List of GPU devices')
parser.add_argument('--img_ext', type=str, default='.jpg', help='Image extentions')
parser.add_argument('--checkpoint_dir', type=str, default='vae/model_weights', help='Path to where model weights will be saved')

# for sampler
parser.add_argument('--sample', action="store_true", default=False, help='Sample from VAE')
parser.add_argument('--walk', action="store_true", default=False, help='Walk through a feature & sample')
args = parser.parse_args()


Expand Down
2 changes: 1 addition & 1 deletion vae/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def forward(self, x):

def sample(self, z=None, num_samples=50):
if z is None:
z = torch.randn(num_samples, self.args.z_dim, device=self.device)
z = torch.randn(num_samples, self.args.z_dim, device=next(self.parameters()).device)
z = self.projector(z).view(-1, self.args.model_dim * 8,
self.args.img_size // (2**4), self.args.img_size // (2**4))
return self.decoder(z)
51 changes: 51 additions & 0 deletions vae/sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import argparse
import torch
import numpy as np
from pathlib import Path
from datetime import datetime
import torchvision

from vae.model import VAE
from vae.main import parser, args


class Sampler:
def __init__(self, sample_path='vae/samples', ext='.jpg'):
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
self.vae = torch.nn.DataParallel(VAE(args), device_ids=args.device_ids).to(self.device)
self.vae.load_state_dict(torch.load(f"{args.checkpoint_dir}/VAE.pth")['model'])
self.vae.eval()
Path(sample_path).mkdir(parents=True, exist_ok=True)

self.sample_path = sample_path
self.ext = ext

def sample(self):
with torch.no_grad():
samples = self.vae.module.sample(num_samples=args.sample_size)
torchvision.utils.save_image(
samples,
self.sample_path + f'/sample_{int(datetime.now().timestamp()*1e6)}' + self.ext)

def generate_walk_z(self):
z = torch.randn(args.z_dim, device=self.device)
z = z.repeat(args.sample_size).view(args.sample_size, args.z_dim)
walk_dim = np.random.choice(list(range(args.z_dim)))
z[:, walk_dim] = torch.linspace(-2, 2, args.sample_size)
return z

def walk(self):
z = self.generate_walk_z()
with torch.no_grad():
samples = self.vae.module.sample(z=z)
torchvision.utils.save_image(
samples,
self.sample_path + f'/walk_{int(datetime.now().timestamp()*1e6)}' + self.ext)


if __name__ == '__main__':
sampler = Sampler()
if args.sample:
sampler.sample()
if args.walk:
sampler.walk()

0 comments on commit 1341dbb

Please sign in to comment.