diff --git a/gan/conditional_layers.py b/gan/conditional_layers.py index 472d73a..9e707b4 100644 --- a/gan/conditional_layers.py +++ b/gan/conditional_layers.py @@ -634,7 +634,7 @@ def __init__(self, self.decomposition = decomposition self.iter_num = iter_num - def cov_initializer(self, shape, dtype=tf.float32, partitional_info=None): + def cov_initializer(self, shape, dtype=tf.float32, partition_info=None): moving_convs = [] for i in range(shape[0]): moving_conv = tf.expand_dims(tf.eye(shape[1], dtype=dtype), 0) @@ -716,11 +716,11 @@ def get_inv_sqrt(ff, m_per_group): projection = tf.eye(m_per_group) projection = tf.expand_dims(projection, 0) - projection = tf.tile(projection, [self.groups, 1, 1]) + projection = tf.tile(projection, [self.group, 1, 1]) for i in range(self.iter_num): projection = (3 * projection - projection * projection * projection * sigma_norm) / 2 - return projection / tf.sqrt(trace) + return None, projection / tf.sqrt(trace) else: assert False diff --git a/run.py b/run.py index 096826b..bd3724f 100644 --- a/run.py +++ b/run.py @@ -310,7 +310,7 @@ def main(): help="Layer after block normalization. ccs - conditional shift and scale." "ucs - uncoditional shift and scale. ucconv - condcoloring. ufconv - condcoloring + sa." "n - None.") - parser.add_argument("--decomposition", default='cholesky', choices=['cholesky', 'zca', 'pca', 'iter'], help='') + parser.add_argument("--decomposition", default='cholesky', choices=['cholesky', 'zca', 'pca', 'iter_norm'], help='') parser.add_argument("--group", default=1, type=int, help='') parser.add_argument("--iter_num", default=5, type=int, help='') parser.add_argument("--generator_batch_multiple", default=2, type=int, diff --git a/scripts/test/cifar10_dcgan_sn_uncond_light_iter_group.py b/scripts/test/cifar10_dcgan_sn_uncond_light_iter_group.py index f4c785c..a4bbb1d 100644 --- a/scripts/test/cifar10_dcgan_sn_uncond_light_iter_group.py +++ b/scripts/test/cifar10_dcgan_sn_uncond_light_iter_group.py @@ -5,15 +5,15 @@ os.system('python run.py --name {} --dataset cifar10 --generator_adversarial_objective hinge\ --discriminator_adversarial_objective hinge --generator_block_norm d --generator_block_after_norm uconv\ - --generator_last_norm d --decomposition iter --iter_num 5 --group 32 --generator_last_after_norm uconv --discriminator_filters 256 --generator_filters 256\ + --generator_last_norm d --decomposition iter_norm --iter_num 5 --group 32 --generator_last_after_norm uconv --discriminator_filters 512 --generator_filters 512\ --discriminator_spectral 1 --gradient_penalty_weight 0 --lr_decay_schedule linear --number_of_epochs 50 --arc dcgan --training_ratio 1 --generator_batch_multiple 1'.format(base_name)) os.system('python run.py --name {} --dataset cifar10 --generator_adversarial_objective hinge\ --discriminator_adversarial_objective hinge --generator_block_norm d --generator_block_after_norm uconv\ - --generator_last_norm d --decomposition iter --iter_num 5 --group 1 --generator_last_after_norm uconv --discriminator_filters 256 --generator_filters 256\ + --generator_last_norm d --decomposition iter_norm --iter_num 5 --group 1 --generator_last_after_norm uconv --discriminator_filters 512 --generator_filters 512\ --discriminator_spectral 1 --gradient_penalty_weight 0 --lr_decay_schedule linear --number_of_epochs 50 --arc dcgan --training_ratio 1 --generator_batch_multiple 1'.format(base_name)) os.system('python run.py --name {} --dataset cifar10 --generator_adversarial_objective hinge\ --discriminator_adversarial_objective hinge --generator_block_norm d --generator_block_after_norm uconv\ - --generator_last_norm d --decomposition iter --iter_num 5 --group 16 --generator_last_after_norm uconv --discriminator_filters 256 --generator_filters 256\ + --generator_last_norm d --decomposition iter_norm --iter_num 5 --group 16 --generator_last_after_norm uconv --discriminator_filters 512 --generator_filters 512\ --discriminator_spectral 1 --gradient_penalty_weight 0 --lr_decay_schedule linear --number_of_epochs 50 --arc dcgan --training_ratio 1 --generator_batch_multiple 1'.format(base_name))