Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
bhneo committed Oct 1, 2019
2 parents 17dcb48 + abe2890 commit f917406
Show file tree
Hide file tree
Showing 28 changed files with 597 additions and 438 deletions.
45 changes: 11 additions & 34 deletions gan/conditional_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,29 +641,15 @@ def build(self, input_shape):
synchronization=tf_variables.VariableSynchronization.ON_READ,
initializer=self.moving_mean_initializer,
trainable=False,
aggregation=tf_variables.VariableAggregation.MEAN,
experimental_autocast=False)
aggregation=tf_variables.VariableAggregation.MEAN)
self.moving_cov = self.add_weight(shape=(dim, dim),
name='moving_variance',
synchronization=tf_variables.VariableSynchronization.ON_READ,
initializer=self.moving_cov_initializer,
trainable=False,
aggregation=tf_variables.VariableAggregation.MEAN,
experimental_autocast=False)
aggregation=tf_variables.VariableAggregation.MEAN)
self.built = True

# def _assign_moving_average(self, variable, value, momentum, inputs_size):
# with K.name_scope('AssignMovingAvg') as scope:
# with ops.colocate_with(variable):
# decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
# if decay.dtype != variable.dtype.base_dtype:
# decay = math_ops.cast(decay, variable.dtype.base_dtype)
# update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay
# if inputs_size is not None:
# update_delta = array_ops.where(inputs_size > 0, update_delta,
# K.zeros_like(update_delta))
# return state_ops.assign_sub(variable, update_delta, name=scope)

def call(self, inputs, training=None):
_, w, h, c = K.int_shape(inputs)
bs = K.shape(inputs)[0]
Expand All @@ -681,25 +667,25 @@ def call(self, inputs, training=None):
if self.decomposition == 'cholesky':
def get_inv_sqrt(ff):
with tf.device('/cpu:0'):
sqrt = tf.linalg.cholesky(ff)
sqrt = tf.cholesky(ff)
inv_sqrt = tf.linalg.triangular_solve(sqrt, tf.eye(c))
return sqrt, inv_sqrt
elif self.decomposition == 'zca':
def get_inv_sqrt(ff):
with tf.device('/cpu:0'):
S, U, _ = tf.linalg.svd(ff + tf.eye(c)*self.epsilon, full_matrices=True)
D = tf.linalg.diag(tf.pow(S, -0.5))
S, U, _ = tf.svd(ff + tf.eye(c)*self.epsilon, full_matrices=True)
D = tf.diag(tf.pow(S, -0.5))
inv_sqrt = tf.matmul(tf.matmul(U, D), U, transpose_b=True)
D = tf.linalg.diag(tf.pow(S, 0.5))
D = tf.diag(tf.pow(S, 0.5))
sqrt = tf.matmul(tf.matmul(U, D), U, transpose_b=True)
return sqrt, inv_sqrt
elif self.decomposition == 'pca':
def get_inv_sqrt(ff):
with tf.device('/cpu:0'):
S, U, _ = tf.linalg.svd(ff + tf.eye(c)*self.epsilon, full_matrices=True)
D = tf.linalg.diag(tf.pow(S, -0.5))
S, U, _ = tf.svd(ff + tf.eye(c)*self.epsilon, full_matrices=True)
D = tf.diag(tf.pow(S, -0.5))
inv_sqrt = tf.matmul(D, U, transpose_b=True)
D = tf.linalg.diag(tf.pow(S, 0.5))
D = tf.diag(tf.pow(S, 0.5))
sqrt = tf.matmul(D, U, transpose_b=True)
return sqrt, inv_sqrt
else:
Expand All @@ -712,9 +698,8 @@ def train():
self.momentum),
K.moving_average_update(self.moving_cov,
ff_apr,
self.momentum)])
# self.add_update([self._assign_moving_average(self.moving_mean, m, self.momentum, None),
# self._assign_moving_average(self.moving_cov, ff_apr, self.momentum, None)])
self.momentum)],
inputs)
ff_apr_shrinked = (1 - self.epsilon) * ff_apr + tf.eye(c) * self.epsilon

if self.renorm:
Expand Down Expand Up @@ -1734,11 +1719,3 @@ def layer(x):
kernel_initializer=glorot_init, name=kwargs['name'] + '-c_part')([out, cls])
return Add()([out_u, out_c])
return layer


def test_dbn():
data = tf.random.normal([128, 16, 16, 8])
K.set_learning_phase(1)
out = DecorelationNormalization(decomposition='zca')(data, training=True)
print()

35 changes: 16 additions & 19 deletions gan/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class InvalidFIDException(Exception):
def create_inception_graph(pth):
"""Creates a graph from saved GraphDef file."""
# Creates graph from saved graph_def.pb.
with tf.compat.v1.gfile.FastGFile( pth, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
with tf.gfile.FastGFile( pth, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='FID_Inception_Net')
#-------------------------------------------------------------------------------
Expand All @@ -54,16 +54,15 @@ def _get_inception_layer(sess):
for o in op.outputs:
shape = o.get_shape()
if shape._dims is not None:
shape = o.get_shape().as_list()
# shape = [s.value for s in shape]
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
# o._shape = tf.TensorShape(new_shape)
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
shape = o.get_shape().as_list()
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
# o._shape = tf.TensorShape(new_shape)
o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
return pool3
#-------------------------------------------------------------------------------

Expand Down Expand Up @@ -228,27 +227,25 @@ def calculate_fid_given_paths(paths, inception_path):
raise RuntimeError("Invalid path: %s" % p)

create_inception_graph(str(inception_path))
with tf.compat.v1.Session() as sess:
sess.run(tf.compat.v1.global_variables_initializer())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
m1, s1 = _handle_path(paths[0], sess)
m2, s2 = _handle_path(paths[1], sess)
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
return fid_value


initialized = False


def calculate_fid_given_arrays(arrays, cache_file=None):
print ("Computing FID...")
global m_true_data, s_true_data, initialized
if not initialized:
inception_path = check_or_download_inception(None)
create_inception_graph(str(inception_path))
with tf.compat.v1.Session() as sess:
sess.run(tf.compat.v1.global_variables_initializer())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
initialized = True
with tf.compat.v1.Session() as sess:
with tf.Session() as sess:
if cache_file is None:
m_true_data, s_true_data = calculate_activation_statistics(arrays[0], sess)
else:
Expand Down
Loading

0 comments on commit f917406

Please sign in to comment.