-
Notifications
You must be signed in to change notification settings - Fork 85
/
made.py
144 lines (114 loc) · 5.83 KB
/
made.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
Implements Masked AutoEncoder for Density Estimation, by Germain et al. 2015
Re-implementation by Andrej Karpathy based on https://arxiv.org/abs/1502.03509
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# ------------------------------------------------------------------------------
class MaskedLinear(nn.Linear):
    """ nn.Linear variant whose weights are multiplied elementwise by a settable mask """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        # the mask has the same (out_features, in_features) layout as the weight
        # and starts as all ones, i.e. initially nothing is masked out. It is
        # registered as a buffer rather than a parameter.
        all_ones = torch.ones(out_features, in_features)
        self.register_buffer('mask', all_ones)

    def set_mask(self, mask):
        """
        Overwrite the stored mask in place.
        mask: numpy array; it is cast to uint8 and transposed before the copy, so
              it is expected in (in_features, out_features) layout.
        """
        transposed = mask.astype(np.uint8).T
        self.mask.data.copy_(torch.from_numpy(transposed))

    def forward(self, input):
        """ Apply the linear map using the masked weights and the (unmasked) bias """
        masked_weight = self.mask * self.weight
        return F.linear(input, masked_weight, self.bias)
class MADE(nn.Module):
    """
    MLP made of MaskedLinear layers with ReLU nonlinearities in between, whose
    weight masks are sampled so that every connection path from an input to an
    output respects an ordering of the input dimensions.
    """

    def __init__(self, nin, hidden_sizes, nout, num_masks=1, natural_ordering=False):
        """
        nin: integer; number of inputs
        hidden sizes: a sequence of integers; number of units in hidden layers
        nout: integer; number of outputs, which usually collectively parameterize some kind of 1D distribution
              note: if nout is e.g. 2x larger than nin (perhaps the mean and std), then the first nin
              will be all the means and the second nin will be stds. i.e. output dimensions depend on the
              same input dimensions in "chunks" and should be carefully decoded downstream appropriately.
              the output of running the tests for this file makes this a bit more clear with examples.
        num_masks: can be used to train ensemble over orderings/connections
        natural_ordering: force natural ordering of dimensions, don't use random permutations
        """
        super().__init__()
        self.nin = nin
        self.nout = nout
        # copy into a list so that tuples (or any other sequence) can be
        # concatenated with the input/output sizes below
        self.hidden_sizes = list(hidden_sizes)
        assert self.nout % self.nin == 0, "nout must be integer multiple of nin"

        # define a simple MLP neural net
        hs = [nin] + self.hidden_sizes + [nout]
        layers = []
        for h0, h1 in zip(hs, hs[1:]):
            layers.extend([
                MaskedLinear(h0, h1),
                nn.ReLU(),
            ])
        layers.pop()  # pop the last ReLU for the output layer
        self.net = nn.Sequential(*layers)

        # seeds for orders/connectivities of the model ensemble
        self.natural_ordering = natural_ordering
        self.num_masks = num_masks
        self.seed = 0  # for cycling through num_masks orderings

        # self.m[-1] holds the input ordering, self.m[l] the sampled integer
        # assigned to each unit of hidden layer l
        self.m = {}
        self.update_masks()  # builds the initial self.m connectivity
        # note, we could also precompute the masks and cache them, but this
        # could get memory expensive for large number of masks.

    def update_masks(self):
        """
        Sample the ordering/connectivity for the current seed, install the
        resulting masks into every MaskedLinear layer, and advance the seed
        (cycling through 0 .. num_masks-1). With num_masks == 1 only the very
        first call does any work.
        """
        if self.m and self.num_masks == 1:
            return  # only a single seed, skip for efficiency
        L = len(self.hidden_sizes)

        # fetch the next seed and construct a random stream
        rng = np.random.RandomState(self.seed)
        self.seed = (self.seed + 1) % self.num_masks

        # sample the order of the inputs and the connectivity of all neurons
        self.m[-1] = np.arange(self.nin) if self.natural_ordering else rng.permutation(self.nin)
        for l in range(L):
            low = self.m[l-1].min()
            # randint requires low < high. For nin >= 2 the bound below is
            # always nin-1 (same samples as before); for nin == 1 it is
            # clamped so that sampling does not raise and every hidden unit
            # gets 0. The strict '<' of the output mask then still cuts the
            # single output off from the single input.
            high = max(low + 1, self.nin - 1)
            self.m[l] = rng.randint(low, high, size=self.hidden_sizes[l])

        # construct the mask matrices, each in (in_units, out_units) layout:
        # hidden layers use '<=' between consecutive layers, the output layer
        # uses a strict '<' against the input ordering
        masks = [self.m[l-1][:, None] <= self.m[l][None, :] for l in range(L)]
        masks.append(self.m[L-1][:, None] < self.m[-1][None, :])

        # handle the case where nout = nin * k, for integer k > 1
        if self.nout > self.nin:
            k = self.nout // self.nin
            # replicate the mask across the other outputs
            masks[-1] = np.concatenate([masks[-1]]*k, axis=1)

        # set the masks in all MaskedLinear layers
        layers = [l for l in self.net.modules() if isinstance(l, MaskedLinear)]
        for l, m in zip(layers, masks):
            l.set_mask(m)

    def forward(self, x):
        """ Run x through the masked MLP and return its output """
        return self.net(x)
# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # run a quick and dirty test for the autoregressive property:
    # for every output k, backpropagate from that single output to the input
    # and record which input dimensions receive a nonzero gradient. The check
    # passes ("OK") when output k does not depend on input k % nin.
    D = 10
    rng = np.random.RandomState(14)
    x = (rng.rand(1, D) > 0.5).astype(np.float32)

    configs = [
        (D, [], D, False),                 # test various hidden sizes
        (D, [200], D, False),
        (D, [200, 220], D, False),
        (D, [200, 220, 230], D, False),
        (D, [200, 220], D, True),          # natural ordering test
        (D, [200, 220], 2*D, True),       # test nout > nin
        (D, [200, 220], 3*D, False),       # test nout > nin
    ]

    for nin, hiddens, nout, natural_ordering in configs:

        print("checking nin %d, hiddens %s, nout %d, natural %s" %
              (nin, hiddens, nout, natural_ordering))
        model = MADE(nin, hiddens, nout, natural_ordering=natural_ordering)

        # run backpropagation for each dimension to compute what other
        # dimensions it depends on.
        res = []
        for k in range(nout):
            # a fresh leaf tensor per output so that .grad starts out empty;
            # plain tensors track gradients directly, so the deprecated
            # torch.autograd.Variable wrapper is not needed
            xtr = torch.from_numpy(x).requires_grad_(True)
            xtrhat = model(xtr)
            loss = xtrhat[0, k]
            loss.backward()

            depends = (xtr.grad[0].numpy() != 0).astype(np.uint8)
            depends_ix = list(np.where(depends)[0])
            isok = k % nin not in depends_ix

            res.append((len(depends_ix), k, depends_ix, isok))

        # pretty print the dependencies, fewest dependencies first
        res.sort()
        for nl, k, ix, isok in res:
            print("output %2d depends on inputs: %30s : %s" % (k, ix, "OK" if isok else "NOTOK"))