# forked from luost26/diffusion-point-cloud
# gen_synth.py — generate synthetic point clouds with a pretrained diffusion VAE.
import os
import time
import math
import argparse
import torch
from tqdm.auto import tqdm
from utils.dataset import *
from utils.misc import *
from utils.data import *
from models.vae_gaussian import *
from models.vae_flow import *
from models.flow import add_spectral_norm, spectral_norm_power_iteration
from evaluation import *
def normalize_point_clouds(pcs, mode, logger):
    """Normalize a batch of point clouds, one cloud at a time, in place.

    Args:
        pcs: Tensor of point clouds; assumed (B, N, 3) — the reshape(1, 3)
            calls below require 3-D points.
        mode: One of:
            None         -- no-op, return pcs unchanged;
            'shape_unit' -- center on the mean, scale by the std of all
                            coordinates;
            'shape_bbox' -- center on the bounding box, scale the longest
                            side into [-1, 1].
        logger: Logger used for progress messages.

    Returns:
        The same tensor ``pcs`` with each cloud normalized in place.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    if mode is None:
        logger.info('Will not normalize point clouds.')
        return pcs
    logger.info('Normalization mode: %s' % mode)
    for i in tqdm(range(pcs.size(0)), desc='Normalize'):
        pc = pcs[i]
        if mode == 'shape_unit':
            # Zero-mean, unit standard deviation over all coordinates.
            shift = pc.mean(dim=0).reshape(1, 3)
            scale = pc.flatten().std().reshape(1, 1)
        elif mode == 'shape_bbox':
            # Center of the axis-aligned bounding box; half the longest side
            # as the scale, so the cloud fits into [-1, 1].
            pc_max, _ = pc.max(dim=0, keepdim=True)  # (1, 3)
            pc_min, _ = pc.min(dim=0, keepdim=True)  # (1, 3)
            shift = ((pc_min + pc_max) / 2).view(1, 3)
            scale = (pc_max - pc_min).max().reshape(1, 1) / 2
        else:
            # Fix: the original silently fell through here with shift/scale
            # undefined, crashing with a NameError on the next line.
            raise ValueError('Unknown normalization mode: %s' % mode)
        pcs[i] = (pc - shift) / scale
    return pcs
# Command-line arguments. Declared as a data table and registered in a loop;
# flags, types, defaults and order are identical to the original calls.
parser = argparse.ArgumentParser()
_arg_specs = [
    ('--ckpt', dict(type=str, default='./pretrained/GEN_airplane.pt')),
    ('--categories', dict(type=str_list, default=['airplane'])),
    ('--save_dir', dict(type=str, default='./results')),
    ('--device', dict(type=str, default='cuda')),
    # Datasets and loaders
    ('--batch_size', dict(type=int, default=128)),
    # Sampling
    ('--sample_num_points', dict(type=int, default=2048)),
    ('--normalize', dict(type=str, default='shape_bbox',
                         choices=[None, 'shape_unit', 'shape_bbox'])),
    ('--seed', dict(type=int, default=9988)),
    # Generation
    ('--num', dict(type=int, default=4000)),  # number of point clouds to generate
]
for _flag, _kwargs in _arg_specs:
    parser.add_argument(_flag, **_kwargs)
args = parser.parse_args()
# Logging: every run writes into a fresh, timestamped output directory.
save_dir = os.path.join(args.save_dir, 'GEN_%d' % (int(time.time())))
# exist_ok=True replaces the original check-then-create (os.path.exists +
# os.makedirs), which could race with a concurrent run at the same timestamp.
os.makedirs(save_dir, exist_ok=True)
logger = get_logger('test', save_dir)
# Record every CLI argument so the run is reproducible from the log alone.
for k, v in vars(args).items():
    logger.info('[ARGS::%s] %s' % (k, repr(v)))
# Checkpoint
# NOTE(review): torch.load unpickles arbitrary objects — only point --ckpt at
# trusted checkpoint files.
ckpt = torch.load(args.ckpt, map_location=args.device)
seed_all(args.seed)
# Model: the checkpoint records which VAE variant it was trained with.
logger.info('Loading model...')
if ckpt['args'].model == 'gaussian':
    model = GaussianVAE(ckpt['args']).to(args.device)
elif ckpt['args'].model == 'flow':
    model = FlowVAE(ckpt['args']).to(args.device)
else:
    # Fix: the original fell through here, leaving `model` undefined and
    # crashing with a NameError on the logger.info(repr(model)) line below.
    raise ValueError('Unknown model type: %s' % ckpt['args'].model)
logger.info(repr(model))
# if ckpt['args'].spectral_norm:
#     add_spectral_norm(model, logger=logger)
model.load_state_dict(ckpt['state_dict'])
# Generate point clouds batch by batch, then trim to exactly args.num.
gen_pcs = []
num_batches = math.ceil(args.num / args.batch_size)
for _batch in tqdm(range(num_batches), 'Generate'):
    with torch.no_grad():
        # Sample latent codes, decode them into point clouds.
        latent = torch.randn([args.batch_size, ckpt['args'].latent_dim]).to(args.device)
        sampled = model.sample(latent, args.sample_num_points, flexibility=ckpt['args'].flexibility)
        gen_pcs.append(sampled.detach().cpu())
gen_pcs = torch.cat(gen_pcs, dim=0)[:args.num]
if args.normalize is not None:
    gen_pcs = normalize_point_clouds(gen_pcs, mode=args.normalize, logger=logger)
# Save the generated set as a single numpy array.
logger.info('Saving point clouds...')
out_path = os.path.join(save_dir, 'out.npy')
np.save(out_path, gen_pcs.numpy())