-
Notifications
You must be signed in to change notification settings - Fork 14
/
VAE.lua
40 lines (31 loc) · 1.11 KB
/
VAE.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
require 'torch'
require 'nn'
-- Module table: the network-builder functions are attached to it and it is
-- returned at the end of the file.
local VAE = {}
--- Build the encoder network.
-- Layout: Linear(input_size -> hidden_layer_size), in-place ReLU, then a
-- ConcatTable holding two parallel Linear(hidden_layer_size ->
-- latent_variable_size) heads. A forward pass therefore yields a table with
-- two outputs, one per head (by the variable name, presumably the mean and
-- the log-variance of the latent distribution, in that order -- confirm
-- against the caller).
-- @param input_size            width of the input layer
-- @param hidden_layer_size     width of the hidden layer
-- @param latent_variable_size  width of each of the two output heads
-- @return the assembled nn.Sequential container
function VAE.get_encoder(input_size, hidden_layer_size, latent_variable_size)
  local encoder = nn.Sequential()
  encoder:add(nn.Linear(input_size, hidden_layer_size))
  encoder:add(nn.ReLU(true))

  -- Must be `local`: a bare assignment would create/overwrite a global named
  -- `mean_logvar` on every call.
  local mean_logvar = nn.ConcatTable()
  mean_logvar:add(nn.Linear(hidden_layer_size, latent_variable_size))
  mean_logvar:add(nn.Linear(hidden_layer_size, latent_variable_size))
  encoder:add(mean_logvar)

  return encoder
end
--- Build the decoder network.
-- Layout: Linear(latent_variable_size -> hidden_layer_size), in-place ReLU,
-- followed by an output stage that depends on `continuous`:
--   * truthy: a ConcatTable with two parallel Linear(hidden_layer_size ->
--     input_size) heads, so the forward pass yields a table of two outputs
--     (by the variable name, presumably mean and log-variance -- confirm
--     against the caller);
--   * falsy/nil: a single Linear(hidden_layer_size -> input_size) followed by
--     a Sigmoid, so the forward pass yields one output.
-- @param input_size            width of the reconstructed output
-- @param hidden_layer_size     width of the hidden layer
-- @param latent_variable_size  width of the decoder's input
-- @param continuous            selects the two-headed output stage when truthy
-- @return the assembled nn.Sequential container
function VAE.get_decoder(input_size, hidden_layer_size, latent_variable_size, continuous)
  local decoder = nn.Sequential()
  decoder:add(nn.Linear(latent_variable_size, hidden_layer_size))
  decoder:add(nn.ReLU(true))

  if continuous then
    -- Must be `local`: a bare assignment would create/overwrite a global
    -- named `mean_logvar` on every call.
    local mean_logvar = nn.ConcatTable()
    mean_logvar:add(nn.Linear(hidden_layer_size, input_size))
    mean_logvar:add(nn.Linear(hidden_layer_size, input_size))
    decoder:add(mean_logvar)
  else
    decoder:add(nn.Linear(hidden_layer_size, input_size))
    -- nn.Sigmoid takes no constructor argument (it has no in-place option,
    -- unlike nn.ReLU), so the stray `true` formerly passed here was dropped.
    decoder:add(nn.Sigmoid())
  end

  return decoder
end
-- Export the module table as the result of require().
return VAE