We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 974987e · commit 4127ba0 (Copy full SHA for 4127ba0)
autoencoders/variational_autoencoder.py
@@ -16,7 +16,11 @@ def __init__(self, n_features=24, z_dim=15):
16
self.en1 = nn.Linear(n_features, 200)
17
self.en2 = nn.Linear(200, 100)
18
self.en3 = nn.Linear(100, 50)
19
- self.en4 = nn.Linear(50, z_dim)
+
20
+ # distribution parameters
21
+ self.fc_mu = nn.Linear(50, z_dim)
22
+ self.fc_logvar = nn.Linear(50, z_dim)
23
24
self.de1 = nn.Linear(z_dim, 50)
25
self.de2 = nn.Linear(50, 100)
26
self.de3 = nn.Linear(100, 200)
@@ -29,7 +33,7 @@ def encode(self, x):
29
33
h1 = F.leaky_relu(self.en1(x))
30
34
h2 = F.leaky_relu(self.en2(h1))
31
35
h3 = F.leaky_relu(self.en3(h2))
32
- return self.en4(h3), self.en4(h3)
36
+ return self.fc_mu(h3), self.fc_logvar(h3)
37
38
def reparameterize(self, mu, logvar):
39
std = torch.exp(0.5 * logvar)
main.py
@@ -21,7 +21,7 @@
ap.add_argument('-cn', '--custom_norm', type=bool, default=False,
help='Whether to normalize all variables with min_max scaler or also use custom normalization for 4-momentum')
- ap.add_argument('-vae', '--use_vae', type=bool, default=False,
+ ap.add_argument('-vae', '--use_vae', type=bool, default=True,
help='Whether to use Variational AE')
ap.add_argument('-sae', '--use_sae', type=bool, default=False,
27
help='Whether to use Sparse AE')
0 commit comments