-
Notifications
You must be signed in to change notification settings - Fork 90
/
test_ae.py
80 lines (67 loc) · 2.49 KB
/
test_ae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import argparse
import os
import time

import numpy as np
import torch
from tqdm.auto import tqdm

from utils.dataset import *
from utils.misc import *
from utils.data import *
from models.autoencoder import *
from evaluation import EMD_CD
# Command-line arguments.
parser = argparse.ArgumentParser()

# Checkpoint to evaluate, shape categories, output root, and the torch device
# that tensors and the model are moved to.
parser.add_argument(
    '--ckpt',
    type=str,
    default='./pretrained/AE_airplane.pt',
)
parser.add_argument(
    '--categories',
    type=str_list,  # project helper (wildcard-imported) that parses the category list
    default=['airplane'],
)
parser.add_argument(
    '--save_dir',
    type=str,
    default='./results',
)
parser.add_argument(
    '--device',
    type=str,
    default='cuda',
)

# Test dataset location and evaluation batch size.
parser.add_argument(
    '--dataset_path',
    type=str,
    default='./data/shapenet.hdf5',
)
parser.add_argument(
    '--batch_size',
    type=int,
    default=128,
)

args = parser.parse_args()
# Logging: each run writes into a directory under --save_dir named after the
# evaluated categories plus the current Unix timestamp (seconds).
save_dir = os.path.join(
    args.save_dir,
    'AE_Ours_%s_%d' % ('_'.join(args.categories), int(time.time())),
)
# exist_ok=True replaces the racy "if not exists: makedirs" check-then-create.
os.makedirs(save_dir, exist_ok=True)
# get_logger is a project helper; presumably it also logs into save_dir (TODO confirm).
logger = get_logger('test', save_dir)
# Record every CLI argument so the run's configuration is captured in the log.
# Lazy %-args produce exactly '[ARGS::<name>] <repr(value)>', as before.
for k, v in vars(args).items():
    logger.info('[ARGS::%s] %r', k, v)
# Checkpoint: a dict whose 'args' entry carries the training-time options
# (seed, scale_mode, flexibility, ...) and whose 'state_dict' holds weights.
# NOTE(security): torch.load unpickles the file, and 'args' appears to be a
# pickled options object, so only load checkpoints from trusted sources.
# map_location remaps tensors saved on a GPU onto --device, so the script
# also works with --device cpu or on a host without the original GPU.
ckpt = torch.load(args.ckpt, map_location=args.device)
# seed_all is a project helper; presumably seeds every RNG in use (TODO confirm).
seed_all(ckpt['args'].seed)
# Test split of the requested categories, built with the scale_mode recorded
# in the checkpoint's args so inputs match what the model was configured for.
logger.info('Loading datasets...')
test_dset = ShapeNetCore(
    path=args.dataset_path,
    cates=args.categories,
    split='test',
    scale_mode=ckpt['args'].scale_mode,
)
# No shuffle argument: batches come in dataset order (DataLoader default).
test_loader = DataLoader(
    test_dset,
    batch_size=args.batch_size,
    num_workers=0,
)
# Rebuild the autoencoder from the checkpoint's options, place it on the
# requested device, then restore the trained weights.
logger.info('Loading model...')
model = AutoEncoder(ckpt['args'])
model = model.to(args.device)
model.load_state_dict(ckpt['state_dict'])
# Reconstruct every test shape: encode each batch to a latent code, decode it
# back to the same number of points, then map both the reference and the
# reconstruction back with x * scale + shift (assumes the dataset stored each
# shape as (x - shift) / scale; TODO confirm against ShapeNetCore).
model.eval()  # set once; previously re-applied on every batch
all_ref = []
all_recons = []
with torch.no_grad():
    for batch in tqdm(test_loader):
        ref = batch['pointcloud'].to(args.device)
        shift = batch['shift'].to(args.device)
        scale = batch['scale'].to(args.device)
        code = model.encode(ref)
        # ref.size(1) is passed as the point count; assumes ref is
        # (batch, num_points, 3). TODO confirm.
        recons = model.decode(
            code, ref.size(1), flexibility=ckpt['args'].flexibility
        ).detach()
        ref = ref * scale + shift
        recons = recons * scale + shift
        all_ref.append(ref.cpu())
        all_recons.append(recons.cpu())

# torch.cat on an empty list fails with an unhelpful message; explain instead.
if not all_ref:
    raise RuntimeError(
        'No test samples were loaded from %s for categories %s'
        % (args.dataset_path, args.categories)
    )
all_ref = torch.cat(all_ref, dim=0)
all_recons = torch.cat(all_recons, dim=0)
# Persist the mapped-back clouds: references as ref.npy, reconstructions as
# out.npy, both inside this run's save_dir.
logger.info('Saving point clouds...')
np.save(os.path.join(save_dir, 'ref.npy'), all_ref.numpy())
np.save(os.path.join(save_dir, 'out.npy'), all_recons.numpy())

# Score reconstructions against references with the project's EMD_CD
# evaluator. The 'MMD-CD' / 'MMD-EMD' entries are presumably minimum-matching
# distances under Chamfer / Earth Mover's distance (TODO confirm).
logger.info('Start computing metrics...')
metrics = EMD_CD(
    all_recons.to(args.device),
    all_ref.to(args.device),
    batch_size=args.batch_size,
)
cd = metrics['MMD-CD'].item()
emd = metrics['MMD-EMD'].item()
logger.info('CD: %.12f', cd)
logger.info('EMD: %.12f', emd)