
Commit 7bfff2a

AE network tuning and optimization
1 parent 7acf174 commit 7bfff2a

2 files changed: +52 -28 lines

autoencoders/standard_AE.py

Lines changed: 50 additions & 27 deletions
@@ -1,31 +1,32 @@
 import pandas as pd
 import tensorflow as tf
 import numpy as np
-import datetime
 import os
+import time
 import matplotlib.pyplot as plt
-from data_preprocessing import preprocess
 from sklearn.utils import shuffle


 # Parameters
 input_dim = 28
 hidden_size1 = 100
-hidden_size2 = 100
+hidden_size2 = 50
+hidden_size3 = 30
 z_dim = 20

 batch_size = 100
-n_epochs = 1000
+n_epochs = 2
 learning_rate = 0.001
 beta1 = 0.9
 results_path = './autoencoders/Results/Standard_AE'
 saved_model_path = results_path + '/Saved_models/'

 # Placeholders for input data and the targets
-x_input = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Input')
-x_target = tf.placeholder(dtype=tf.float32, shape=[batch_size, input_dim], name='Target')
+x_input = tf.placeholder(dtype=tf.float32, shape=[None, input_dim], name='Input')
+x_target = tf.placeholder(dtype=tf.float32, shape=[None, input_dim], name='Target')
 decoder_input = tf.placeholder(dtype=tf.float32, shape=[1, z_dim], name='Decoder_input')

+recontruct = True

 def dense(x, n1, n2, name):
     """
@@ -55,9 +56,11 @@ def encoder(x, reuse=False):
     if reuse:
         tf.get_variable_scope().reuse_variables()
     with tf.name_scope('Encoder'):
-        e_dense_1 = tf.nn.relu(dense(x, input_dim, hidden_size1, 'e_dense_1'))
-        e_dense_2 = tf.nn.relu(dense(e_dense_1, hidden_size1, hidden_size2, 'e_dense_2'))
-        latent_variable = dense(e_dense_2, hidden_size2, z_dim, 'e_latent_variable')
+        e_dense_1 = tf.nn.tanh(dense(x, input_dim, hidden_size1, 'e_dense_1'))
+        e_dense_2 = tf.nn.tanh(dense(e_dense_1, hidden_size1, hidden_size2, 'e_dense_2'))
+        e_dense_3 = tf.nn.tanh(dense(e_dense_2, hidden_size2, hidden_size3, 'e_dense_3'))
+        #latent_variable = dense(e_dense_3, hidden_size3, z_dim, 'e_latent_variable')
+        latent_variable = tf.nn.tanh(dense(e_dense_3, hidden_size3, z_dim, 'e_latent_variable'))
         return latent_variable

@@ -72,12 +75,20 @@ def decoder(x, reuse=False):
     if reuse:
         tf.get_variable_scope().reuse_variables()
     with tf.name_scope('Decoder'):
-        d_dense_1 = tf.nn.relu(dense(x, z_dim, hidden_size2, 'd_dense_1'))
-        d_dense_2 = tf.nn.relu(dense(d_dense_1, hidden_size2, hidden_size1, 'd_dense_2'))
-        output = tf.nn.sigmoid(dense(d_dense_2, hidden_size1, input_dim, 'd_output'))
+        d_dense_1 = tf.nn.tanh(dense(x, z_dim, hidden_size3, 'd_dense_1'))
+        d_dense_2 = tf.nn.tanh(dense(d_dense_1, hidden_size3, hidden_size2, 'd_dense_2'))
+        e_dense_3 = tf.nn.tanh(dense(d_dense_2, hidden_size2, hidden_size1, 'd_dense_3'))
+        output = tf.nn.tanh(dense(e_dense_3, hidden_size1, input_dim, 'd_output'))
         return output


+def reconstruct_variables(sess=None, op=None, data=None):
+    # run the trained AE for predictions on the test data
+    reconstructed_data = sess.run(op, feed_dict={x_input: data})
+    print('Reconstructed data shape: {}'.format(reconstructed_data.shape))
+    # We are going to plot the reconstructed data below
+
+
 def train(train_model=True, train_data=None, test_data=None):
     """
     Used to train the autoencoder by passing in the necessary inputs.
@@ -103,26 +114,38 @@ def train(train_model=True, train_data=None, test_data=None):
     with tf.Session() as sess:
         sess.run(init)
         if train_model:
-
+            start = time.time()
             for i in range(n_epochs):
                 train_data = shuffle(train_data)
                 # break the train data df into chunks of size batch_size
-                train_df = [train_data[x:x + batch_size] for x in range(0, train_data.shape[0], batch_size)]
-                count = 0
-                for batch in train_df:
-                    if batch.shape[0] == batch_size:
-                        count += 1
-                        sess.run(optimizer, feed_dict={x_input: batch, x_target: batch})
-
-                        if count % 50 == 0:
-                            batch_loss = sess.run([loss], feed_dict={x_input: batch, x_target: batch})
-                            print("Loss: {}".format(batch_loss))
-                            print("Epoch: {}, iteration: {}".format(i, count))
-                        step += 1
-                saver.save(sess, save_path=saved_model_path, global_step=step)
+                train_batches = [train_data[x:x + batch_size] for x in range(0, train_data.shape[0], batch_size)]
+
+                mean_loss = 0.0
+                for batch in train_batches:
+                    sess.run(optimizer, feed_dict={x_input: batch, x_target: batch})
+
+                    batch_loss = sess.run([loss], feed_dict={x_input: batch, x_target: batch})
+                    mean_loss += batch_loss[0]
+                    step += 1
+
+                # Calculate the mean loss over all batches in one epoch
+                mean_loss = float(mean_loss)/len(train_batches)
+                # Saving takes a lot of time
+                # saver.save(sess, save_path=saved_model_path, global_step=step)
             print("Model Trained!")

-            print("Saved Model Path: {}".format(saved_model_path))
+            validation_loss = sess.run([loss], feed_dict={x_input: test_data, x_target: test_data})
+            print('\n-------------------------------------------------------------\n')
+            print('Train loss after epoch {}: {}'.format(i, mean_loss))
+            print('Validation loss after epoch {}: {}'.format(i, validation_loss))
+            print("Elapsed time {:.2f} sec".format(time.time() - start))
+            print('\n-------------------------------------------------------------\n')
+
+            # print("Saved Model Path: {}".format(saved_model_path))
+
+            if recontruct == True:
+                reconstruct_variables(sess=sess, op=decoder_output, data=test_data)
+
         else:
             all_results = os.listdir(results_path)
             all_results.sort()
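
A side note on the placeholder change above: widening shape=[batch_size, input_dim] to shape=[None, input_dim] lets the same graph be fed batches of any size, which is what the new validation-loss and reconstruction steps rely on when they feed the entire test set in a single sess.run call. A minimal standalone sketch of that behaviour, assuming TensorFlow 1.x as used in this file and a tf.identity stand-in instead of the real encoder/decoder:

import numpy as np
import tensorflow as tf

x = tf.placeholder(dtype=tf.float32, shape=[None, 28], name='Input')
y = tf.identity(x)  # stand-in for the autoencoder graph

with tf.Session() as sess:
    train_batch = np.zeros((100, 28), dtype=np.float32)     # one training batch
    full_test_set = np.zeros((5000, 28), dtype=np.float32)  # e.g. the whole test set
    print(sess.run(y, feed_dict={x: train_batch}).shape)    # (100, 28)
    print(sess.run(y, feed_dict={x: full_test_set}).shape)  # (5000, 28)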

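The new reconstruct_variables helper currently only prints the shape of the reconstructed test data; its closing comment says the reconstructed data will be plotted. One possible shape for that follow-up, assuming the matplotlib and numpy imports already at the top of the file; the feature index and output filename are illustrative, not part of this commit:

def plot_reconstruction(original, reconstructed, feature_idx=0):
    # Overlay one feature of the original test data and its reconstruction.
    original = np.asarray(original)
    reconstructed = np.asarray(reconstructed)
    plt.figure(figsize=(8, 4))
    plt.plot(original[:, feature_idx], label='original')
    plt.plot(reconstructed[:, feature_idx], label='reconstructed', alpha=0.7)
    plt.xlabel('sample')
    plt.ylabel('feature {}'.format(feature_idx))
    plt.legend()
    plt.savefig(results_path + '/reconstruction_feature_{}.png'.format(feature_idx))
    plt.close()
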
main.py

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,8 @@
 from data_preprocessing import preprocess
 from autoencoders import standard_AE
+from data_loader import load_cms_data

-
+#cms_data_df = load_cms_data(filename="open_cms_data.root")
 train_data, test_data = preprocess()

 standard_AE.train(train_data=train_data, test_data=test_data)
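
Since the decoder's output layer is now tanh rather than sigmoid, reconstructions live in (-1, 1), so the features returned by preprocess() should be scaled into roughly that range for the reconstruction loss to be meaningful. preprocess() is not shown in this commit, so as a hedged sanity check one could run something like the following on the returned frames before training:

import numpy as np

def check_tanh_range(df, lo=-1.0, hi=1.0):
    # A tanh output layer can only produce values strictly inside (lo, hi).
    values = np.asarray(df, dtype=np.float32)
    print('min: {:.3f}, max: {:.3f}'.format(values.min(), values.max()))
    if values.min() < lo or values.max() > hi:
        print('Warning: data outside ({}, {}); a tanh output cannot reach these values.'.format(lo, hi))

check_tanh_range(train_data)
check_tanh_range(test_data)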
