Skip to content

Commit 4127ba0

Browse files
committed
final commit
1 parent 974987e commit 4127ba0

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

autoencoders/variational_autoencoder.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@ def __init__(self, n_features=24, z_dim=15):
1616
self.en1 = nn.Linear(n_features, 200)
1717
self.en2 = nn.Linear(200, 100)
1818
self.en3 = nn.Linear(100, 50)
19-
self.en4 = nn.Linear(50, z_dim)
19+
20+
# distribution parameters
21+
self.fc_mu = nn.Linear(50, z_dim)
22+
self.fc_logvar = nn.Linear(50, z_dim)
23+
2024
self.de1 = nn.Linear(z_dim, 50)
2125
self.de2 = nn.Linear(50, 100)
2226
self.de3 = nn.Linear(100, 200)
@@ -29,7 +33,7 @@ def encode(self, x):
2933
h1 = F.leaky_relu(self.en1(x))
3034
h2 = F.leaky_relu(self.en2(h1))
3135
h3 = F.leaky_relu(self.en3(h2))
32-
return self.en4(h3), self.en4(h3)
36+
return self.fc_mu(h3), self.fc_logvar(h3)
3337

3438
def reparameterize(self, mu, logvar):
3539
std = torch.exp(0.5 * logvar)

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
ap.add_argument('-cn', '--custom_norm', type=bool, default=False,
2222
help='Whether to normalize all variables with min_max scaler or also use custom normalization for 4-momentum')
2323

24-
ap.add_argument('-vae', '--use_vae', type=bool, default=False,
24+
ap.add_argument('-vae', '--use_vae', type=bool, default=True,
2525
help='Whether to use Variational AE')
2626
ap.add_argument('-sae', '--use_sae', type=bool, default=False,
2727
help='Whether to use Sparse AE')

0 commit comments

Comments
 (0)