Skip to content

Commit 1341dbb

Browse files
committed
added sample module
1 parent e0baeff commit 1341dbb

File tree

3 files changed

+57
-1
lines changed

3 files changed

+57
-1
lines changed

vae/main.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
parser = argparse.ArgumentParser()
17+
# training
1718
parser.add_argument('--img_channels', type=int, default=3, help='Numer of channels for images')
1819
parser.add_argument('--model_dim', type=float, default=128, help='model dimensions multiplier')
1920
parser.add_argument('--z_dim', type=float, default=100, help='dimension of random noise latent vector')
@@ -32,6 +33,10 @@
3233
parser.add_argument('--device_ids', type=list, default=[0, 1], help='List of GPU devices')
3334
parser.add_argument('--img_ext', type=str, default='.jpg', help='Image extentions')
3435
parser.add_argument('--checkpoint_dir', type=str, default='vae/model_weights', help='Path to where model weights will be saved')
36+
37+
# for sampler
38+
parser.add_argument('--sample', action="store_true", default=False, help='Sample from VAE')
39+
parser.add_argument('--walk', action="store_true", default=False, help='Walk through a feature & sample')
3540
args = parser.parse_args()
3641

3742

vae/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def forward(self, x):
4646

4747
def sample(self, z=None, num_samples=50):
    """Decode latent vectors into images.

    When *z* is omitted, draws ``num_samples`` random latent vectors from a
    standard normal on the same device the model's parameters live on.
    """
    if z is None:
        # Derive the device from the parameters themselves so the method
        # works regardless of where the module has been moved.
        target = next(self.parameters()).device
        z = torch.randn(num_samples, self.args.z_dim, device=target)
    # Project the flat latent into the decoder's expected feature map:
    # (batch, model_dim * 8, img_size/16, img_size/16).
    side = self.args.img_size // (2 ** 4)
    feats = self.projector(z).view(-1, self.args.model_dim * 8, side, side)
    return self.decoder(feats)

vae/sample.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import argparse
2+
import torch
3+
import numpy as np
4+
from pathlib import Path
5+
from datetime import datetime
6+
import torchvision
7+
8+
from vae.model import VAE
9+
from vae.main import parser, args
10+
11+
12+
class Sampler:
    """Load a trained VAE checkpoint and write generated images to disk.

    Supports two modes: free sampling from the prior (``sample``) and a
    linear walk along one randomly chosen latent dimension (``walk``).
    """

    def __init__(self, sample_path='vae/samples', ext='.jpg'):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        net = VAE(args)
        self.vae = torch.nn.DataParallel(net, device_ids=args.device_ids).to(self.device)
        state = torch.load(f"{args.checkpoint_dir}/VAE.pth")
        self.vae.load_state_dict(state['model'])
        self.vae.eval()
        Path(sample_path).mkdir(parents=True, exist_ok=True)
        self.sample_path = sample_path
        self.ext = ext

    def _save(self, images, prefix):
        # Microsecond timestamp keeps successive output files from colliding.
        stamp = int(datetime.now().timestamp() * 1e6)
        torchvision.utils.save_image(
            images,
            self.sample_path + f'/{prefix}_{stamp}' + self.ext)

    def sample(self):
        """Draw ``args.sample_size`` images from the prior and save them."""
        with torch.no_grad():
            out = self.vae.module.sample(num_samples=args.sample_size)
            self._save(out, 'sample')

    def generate_walk_z(self):
        """Return a batch of identical latents with one random dimension
        swept linearly from -2 to 2 across the batch."""
        base = torch.randn(args.z_dim, device=self.device)
        z = base.repeat(args.sample_size).view(args.sample_size, args.z_dim)
        walk_dim = np.random.choice(list(range(args.z_dim)))
        z[:, walk_dim] = torch.linspace(-2, 2, args.sample_size)
        return z

    def walk(self):
        """Decode a latent-dimension walk and save the resulting strip."""
        z = self.generate_walk_z()
        with torch.no_grad():
            out = self.vae.module.sample(z=z)
            self._save(out, 'walk')
44+
45+
46+
if __name__ == '__main__':
    # Both flags may be set; each mode runs independently.
    runner = Sampler()
    if args.sample:
        runner.sample()
    if args.walk:
        runner.walk()

0 commit comments

Comments
 (0)