import argparse
import itertools
import random

import numpy as np
import torch

import ais
import simulate
import vae


def bdmc(model, loader, forward_schedule, n_sample):
"""Bidirectional Monte Carlo.
Backward schedule is set to be the reverse of the forward schedule.
Args:
model (vae.VAE): VAE model
loader (iterator): iterator to loop over pairs of Variables; the first
entry being `x`, the second being `z` sampled from the *true*
posterior `p(z|x)`
forward_schedule: forward temperature schedule
n_sample (int): number of importance samples
Returns:
two lists for forward and backward bounds on batches of data
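
    Example (a minimal usage sketch; assumes `model`, `loader`, and the
    module-level `device` are set up as in `main()` below):

        schedule = torch.linspace(0, 1, 500, device=device)
        fwd_logws, bwd_logws = bdmc(model, loader, schedule, n_sample=100)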
"""
    # an iterator can be consumed only once in Python 3, so duplicate it for the two chains
loader_forward, loader_backward = itertools.tee(loader, 2)
# forward chain
forward_logws = ais.ais_trajectory(
model,
loader_forward,
forward=True,
schedule=forward_schedule,
n_sample=n_sample,
device=device,
)
# backward chain
backward_schedule = torch.flip(forward_schedule, dims=(0,)).contiguous()
backward_logws = ais.ais_trajectory(
model,
loader_backward,
forward=False,
schedule=backward_schedule,
n_sample=n_sample,
device=device,
)
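    # On simulated data, the forward (AIS) estimates stochastically lower-bound
    # log p(x) and the backward (reverse AIS) estimates upper-bound it; the gap
    # between the two indicates how tight the sandwich is.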
upper_bounds = []
lower_bounds = []
    for forward, backward in zip(forward_logws, backward_logws):
lower_bounds.append(forward.mean().detach().item())
upper_bounds.append(backward.mean().detach().item())
    upper_bound = float(np.mean(upper_bounds))
    lower_bound = float(np.mean(lower_bounds))
    print(
        f"Average bounds on simulated data: lower {lower_bound:.4f}, upper {upper_bound:.4f}"
    )
return forward_logws, backward_logws


def main():
model = vae.VAE(latent_dim=args.latent_dim).to(device).eval()
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    model.load_state_dict(torch.load(args.ckpt_path, map_location=device)['state_dict'])
    # BDMC requires exact posterior samples z ~ p(z|x), so it is run on data
    # simulated from the model itself
loader = simulate.simulate_data(
model,
batch_size=args.batch_size,
n_batch=args.n_batch,
device=device
)
# run bdmc
    # Note: a linear schedule is used here for demonstration; a sigmoidal
    # schedule may be advantageous in some settings, see Section 6 of the
    # original paper: https://arxiv.org/pdf/1511.02543.pdf
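    # A sketch of such a sigmoidal schedule (an illustration, not part of the
    # original script): pass a linear ramp through a sigmoid, then rescale so
    # the schedule still runs exactly from 0 to 1:
    #   beta = torch.sigmoid(torch.linspace(-4, 4, args.chain_length, device=device))
    #   forward_schedule = (beta - beta[0]) / (beta[-1] - beta[0])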
forward_schedule = torch.linspace(0, 1, args.chain_length, device=device)
bdmc(
model,
loader,
forward_schedule=forward_schedule,
n_sample=args.iwae_samples,
)


def manual_seed(seed):
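    """Seed Python's, NumPy's, and PyTorch's RNGs for reproducibility."""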
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='BDMC')
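    # example invocation (assuming a trained checkpoint at the default path):
    #   python bdmc.py --chain-length 500 --iwae-samples 100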
parser.add_argument('--latent-dim', type=int, default=50, metavar='D',
help='number of latent variables (default: 50)')
parser.add_argument('--batch-size', type=int, default=10, metavar='N',
help='number of examples to eval at once (default: 10)')
parser.add_argument('--n-batch', type=int, default=10, metavar='B',
help='number of batches to eval in total (default: 10)')
parser.add_argument('--chain-length', type=int, default=500, metavar='L',
help='length of ais chain (default: 500)')
parser.add_argument('--iwae-samples', type=int, default=100, metavar='I',
help='number of iwae samples (default: 100)')
parser.add_argument('--ckpt-path', type=str, default='checkpoints/model.pth',
metavar='C', help='path to checkpoint')
parser.add_argument('--seed', type=int, default=42)
args = parser.parse_args()
manual_seed(args.seed)
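    # note: bdmc() and main() read `device` as a module-level global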
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
main()