-
Notifications
You must be signed in to change notification settings - Fork 47
/
models.py
156 lines (126 loc) · 4.54 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch
from torch import nn
from torch.nn import functional as F
class Encoder(nn.Module):
    """Encodes a single (x_i, y_i) context pair into a representation r_i.

    Parameters
    ----------
    x_dim : int
        Dimension of x values.
    y_dim : int
        Dimension of y values.
    h_dim : int
        Dimension of hidden layer.
    r_dim : int
        Dimension of output representation r.
    """
    def __init__(self, x_dim, y_dim, h_dim, r_dim):
        super().__init__()
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.h_dim = h_dim
        self.r_dim = r_dim
        # MLP: two ReLU hidden layers, then a linear projection to r_dim.
        self.input_to_hidden = nn.Sequential(
            nn.Linear(x_dim + y_dim, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(inplace=True),
            nn.Linear(h_dim, r_dim),
        )

    def forward(self, x, y):
        """Concatenate x and y along the feature axis and map to r.

        x : torch.Tensor
            Shape (batch_size, x_dim)
        y : torch.Tensor
            Shape (batch_size, y_dim)
        """
        xy = torch.cat((x, y), dim=1)
        return self.input_to_hidden(xy)
class MuSigmaEncoder(nn.Module):
    """Turns an aggregated representation r into the parameters (mu, sigma)
    of the Gaussian distribution from which the latent variable z is sampled.

    Parameters
    ----------
    r_dim : int
        Dimension of output representation r.
    z_dim : int
        Dimension of latent variable z.
    """
    def __init__(self, r_dim, z_dim):
        super().__init__()
        self.r_dim = r_dim
        self.z_dim = z_dim
        self.r_to_hidden = nn.Linear(r_dim, r_dim)
        self.hidden_to_mu = nn.Linear(r_dim, z_dim)
        self.hidden_to_sigma = nn.Linear(r_dim, z_dim)

    def forward(self, r):
        """Map r to (mu, sigma), each of shape (batch_size, z_dim).

        r : torch.Tensor
            Shape (batch_size, r_dim)
        """
        hidden = torch.relu(self.r_to_hidden(r))
        # Bound sigma to (0.1, 1.0), following the convention in "Empirical
        # Evaluation of Neural Process Objectives" and "Attentive Neural
        # Processes".
        sigma = 0.1 + 0.9 * torch.sigmoid(self.hidden_to_sigma(hidden))
        return self.hidden_to_mu(hidden), sigma
class Decoder(nn.Module):
    """
    Maps target input x_target and samples z (encoding information about the
    context points) to predictions y_target.

    Parameters
    ----------
    x_dim : int
        Dimension of x values.
    z_dim : int
        Dimension of latent variable z.
    h_dim : int
        Dimension of hidden layer.
    y_dim : int
        Dimension of y values.
    """
    def __init__(self, x_dim, z_dim, h_dim, y_dim):
        super().__init__()
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.h_dim = h_dim
        self.y_dim = y_dim
        # MLP over concatenated (x, z); three ReLU hidden layers.
        layers = [nn.Linear(x_dim + z_dim, h_dim),
                  nn.ReLU(inplace=True),
                  nn.Linear(h_dim, h_dim),
                  nn.ReLU(inplace=True),
                  nn.Linear(h_dim, h_dim),
                  nn.ReLU(inplace=True)]
        self.xz_to_hidden = nn.Sequential(*layers)
        self.hidden_to_mu = nn.Linear(h_dim, y_dim)
        self.hidden_to_sigma = nn.Linear(h_dim, y_dim)

    def forward(self, x, z):
        """
        x : torch.Tensor
            Shape (batch_size, num_points, x_dim)
        z : torch.Tensor
            Shape (batch_size, z_dim)

        Returns
        -------
        Returns mu and sigma for output distribution. Both have shape
        (batch_size, num_points, y_dim).
        """
        batch_size, num_points, _ = x.size()
        # Repeat z, so it can be concatenated with every x. This changes shape
        # from (batch_size, z_dim) to (batch_size, num_points, z_dim)
        z = z.unsqueeze(1).repeat(1, num_points, 1)
        # Flatten x and z to fit with linear layer. Use reshape rather than
        # view: view raises on non-contiguous inputs (e.g. a sliced or
        # transposed x), while reshape copies only when necessary and is
        # identical for contiguous tensors.
        x_flat = x.reshape(batch_size * num_points, self.x_dim)
        z_flat = z.reshape(batch_size * num_points, self.z_dim)
        # Input is concatenation of z with every row of x
        input_pairs = torch.cat((x_flat, z_flat), dim=1)
        hidden = self.xz_to_hidden(input_pairs)
        mu = self.hidden_to_mu(hidden)
        pre_sigma = self.hidden_to_sigma(hidden)
        # Reshape output into expected shape
        mu = mu.view(batch_size, num_points, self.y_dim)
        pre_sigma = pre_sigma.view(batch_size, num_points, self.y_dim)
        # Define sigma following convention in "Empirical Evaluation of Neural
        # Process Objectives" and "Attentive Neural Processes": softplus keeps
        # it positive, the 0.1 offset keeps it bounded away from zero.
        sigma = 0.1 + 0.9 * F.softplus(pre_sigma)
        return mu, sigma