
Merge pull request tnakae#2 from tnakae/ModelDumpLoad
Add save/restore methods
tnakae authored Jun 2, 2018
2 parents 797b4a0 + 301cdc8 commit 3423318
Showing 1 changed file with 99 additions and 38 deletions.
137 changes: 99 additions & 38 deletions dagmm/dagmm.py
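For context, here is a minimal sketch of the workflow this commit enables. The data shape, hidden-layer sizes, activations, and paths below are illustrative assumptions, not taken from the diff:

import numpy as np
import tensorflow as tf
from dagmm import DAGMM

# Hypothetical data: 1,000 samples with 5 features each.
x = np.random.rand(1000, 5).astype(np.float32)

model = DAGMM(comp_hiddens=[16, 8, 1], comp_activation=tf.nn.tanh,
              est_hiddens=[8, 4], est_activation=tf.nn.tanh)
model.fit(x)
model.save("./trained_model")   # writes DAGMM_model checkpoint files

# Later, possibly in a fresh process: rebuild the wrapper and restore.
model2 = DAGMM(comp_hiddens=[16, 8, 1], comp_activation=tf.nn.tanh,
               est_hiddens=[8, 4], est_activation=tf.nn.tanh)
model2.restore("./trained_model")
energies = model2.predict(x)
threshold = np.percentile(energies, 95)  # e.g. flag the top 5% as anomalies
anomalies = energies > threshold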
@@ -4,6 +4,9 @@
from dagmm.estimation_net import EstimationNet
from dagmm.gmm import GMM

from os import makedirs
from os.path import exists, join

class DAGMM:
""" Deep Autoencoding Gaussian Mixture Model.
@@ -12,6 +15,9 @@ class DAGMM:
for Unsupervised Anomaly Detection, ICLR 2018
(this is an UNOFFICIAL implementation)
"""

MODEL_FILENAME = "DAGMM_model"

def __init__(self, comp_hiddens, comp_activation,
est_hiddens, est_activation, est_dropout_ratio=0.5,
minibatch_size=1024, epoch_size=100,
@@ -62,8 +68,8 @@ def __init__(self, comp_hiddens, comp_activation,
self.lambda1 = lambda1
self.lambda2 = lambda2

self.graph = None
self.sess = None

def __del__(self):
if self.sess is not None:
@@ -79,52 +85,62 @@ def fit(self, x):
"""
n_samples, n_features = x.shape

with tf.Graph().as_default() as graph:
    self.graph = graph

    # Create Placeholder
    self.input = input = tf.placeholder(
        dtype=tf.float32, shape=[None, n_features])
    self.drop = drop = tf.placeholder(dtype=tf.float32, shape=[])

    # Build graph
    z, x_dash = self.comp_net.inference(input)
    gamma = self.est_net.inference(z, drop)
    self.gmm.fit(z, gamma)
    energy = self.gmm.energy(z)

    self.x_dash = x_dash

    # Loss function
    loss = (self.comp_net.reconstruction_error(input, x_dash) +
            self.lambda1 * tf.reduce_mean(energy) +
            self.lambda2 * self.gmm.cov_diag_loss())

    # Minimizer
    minimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(loss)

    # Number of batches
    n_batch = (n_samples - 1) // self.minibatch_size + 1

    # Create tensorflow session and initialize
    init = tf.global_variables_initializer()

    self.sess = tf.Session(graph=graph)
    self.sess.run(init)

    # Training
    for epoch in range(self.epoch_size):
        for batch in range(n_batch):
            i_start = batch * self.minibatch_size
            i_end = (batch + 1) * self.minibatch_size
            x_batch = x[i_start:i_end]

            self.sess.run(minimizer, feed_dict={
                input: x_batch, drop: self.est_dropout_ratio})

        if (epoch + 1) % 100 == 0:
            loss_val = self.sess.run(loss, feed_dict={input: x, drop: 0})
            print(f" epoch {epoch+1}/{self.epoch_size} : loss = {loss_val:.3f}")

    # Fix GMM parameter
    fix = self.gmm.fix_op()
    self.sess.run(fix, feed_dict={input: x, drop: 0})
    self.energy = self.gmm.energy(z)

    tf.add_to_collection("save", self.input)
    tf.add_to_collection("save", self.energy)

    self.saver = tf.train.Saver()
def predict(self, x):
""" Calculate anormaly scores (sample energy) on samples in X.
@@ -140,5 +156,50 @@ def predict(self, x):
energies : array-like, shape (n_samples)
Calculated sample energies.
"""
if self.sess is None:
raise Exception("Trained model does not exist.")

energies = self.sess.run(self.energy, feed_dict={self.input:x})
return energies

def save(self, fdir):
""" Save trained model to designated directory.
This method has to be called after training
(if not, an exception is raised).
Parameters
----------
fdir : str
Path of the directory where the trained model is saved.
If it does not exist, it is created automatically.
"""
if self.sess is None:
raise Exception("Trained model does not exist.")

if not exists(fdir):
makedirs(fdir)

model_path = join(fdir, self.MODEL_FILENAME)
self.saver.save(self.sess, model_path)

def restore(self, fdir):
""" Restore trained model from designated directory.
Parameters
----------
fdir : str
Path of the directory where the trained model was saved.
"""
if not exists(fdir):
raise Exception("Model directory does not exist.")

model_path = join(fdir, self.MODEL_FILENAME)
meta_path = model_path + ".meta"

with tf.Graph().as_default() as graph:
self.graph = graph
self.sess = tf.Session(graph=graph)
self.saver = tf.train.import_meta_graph(meta_path)
self.saver.restore(self.sess, model_path)

self.input, self.energy = tf.get_collection("save")
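The restore path above works because fit() registers the input placeholder and the energy tensor in a graph collection named "save": import_meta_graph rebuilds the graph (collections included) from the .meta file, and tf.get_collection returns the tensors in insertion order, so predict() works without rebuilding the network. Below is a standalone sketch of that round-trip pattern, with hypothetical tensor names and checkpoint path:

import os
import tensorflow as tf

CKPT = "/tmp/dagmm_demo/model"  # hypothetical checkpoint path
os.makedirs(os.path.dirname(CKPT), exist_ok=True)

# --- Save side (mirrors what fit() and save() do) ---
with tf.Graph().as_default():
    x_ph = tf.placeholder(tf.float32, shape=[None, 3])
    w = tf.Variable(tf.ones([3]))            # stand-in for trained parameters
    score = tf.reduce_sum(x_ph * w, axis=1)  # stand-in for the energy tensor
    tf.add_to_collection("save", x_ph)       # insertion order matters:
    tf.add_to_collection("save", score)      # [input, energy]
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, CKPT)

# --- Restore side (mirrors what restore() and predict() do) ---
with tf.Graph().as_default() as graph:
    with tf.Session(graph=graph) as sess:
        saver = tf.train.import_meta_graph(CKPT + ".meta")
        saver.restore(sess, CKPT)
        x_ph, score = tf.get_collection("save")  # recovered in insertion order
        print(sess.run(score, feed_dict={x_ph: [[1.0, 2.0, 3.0]]}))  # -> [6.0]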
