Skip to content

Commit 58d4e5b

Browse files
committed
bump
1 parent c1824aa commit 58d4e5b

File tree

7 files changed

+269
-91
lines changed

7 files changed

+269
-91
lines changed

cond_vaegan.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import deeppy as dp
55
import deeppy.expr as expr
66

7-
from vaegan import KLDivergence, NegativeGradient, SquareError
7+
from vaegan import KLDivergence, NegativeGradient, ScaleGradient, SquareError, WeightedParameter
88

99

1010
class AppendSpatially(expr.base.Binary):
@@ -45,21 +45,30 @@ def __call__(self, x, y):
4545
return x
4646

4747

48-
class ConditionalVAEGAN(dp.base.Model):
48+
class ConditionalVAEGAN(dp.base.Model, dp.base.CollectionMixin):
4949
def __init__(self, encoder, sampler, generator, discriminator, mode,
             reconstruct_error=None, vae_grad_scale=1.0):
    """Store the sub-networks and prepare the generator's parameters.

    Note that ``generator`` is modified in place: its ``params`` list is
    rewritten according to ``mode``.

    Args:
        encoder: Encoder expression.
        sampler: Latent sampler expression.
        generator: Generator expression; its ``params`` are re-wrapped here.
        discriminator: Discriminator expression.
        mode: Mode string; ``'vaegan'`` and ``'gan'`` get special
            generator handling, any other value leaves the unwrapped
            parameters untouched.
        reconstruct_error: Reconstruction loss expression; a
            ``SquareError()`` is created when omitted.
        vae_grad_scale: Weight handed to ``WeightedParameter`` for each
            generator parameter in ``'vaegan'`` mode.
    """
    self.encoder = encoder
    self.sampler = sampler
    self.mode = mode
    self.discriminator = discriminator
    self.vae_grad_scale = vae_grad_scale
    self.eps = 1e-4
    self.reconstruct_error = (
        SquareError() if reconstruct_error is None else reconstruct_error
    )

    # Replace every WeightedParameter by its parent before wrapping again
    # (presumably so that reusing a generator does not stack wrappers).
    unwrapped = []
    for param in generator.params:
        if isinstance(param, WeightedParameter):
            param = param.parent
        unwrapped.append(param)
    generator.params = unwrapped

    if self.mode == 'vaegan':
        generator.params = [
            WeightedParameter(param, vae_grad_scale)
            for param in generator.params
        ]
        # Second generator object whose parameters are shares of the
        # (already wrapped) parameters of the first one.
        self.generator_neg = deepcopy(generator)
        self.generator_neg.params = [
            param.share() for param in generator.params
        ]
    elif self.mode == 'gan':
        generator.params = [
            WeightedParameter(param, -1.0) for param in generator.params
        ]

    self.generator = generator
    self.collection = [
        self.encoder,
        self.sampler,
        self.generator,
        self.discriminator,
    ]
6372

6473
def _embed_expr(self, x, y):
6574
h_enc = self.encoder(x, y)
@@ -81,24 +90,20 @@ def setup(self, x_shape, y_shape):
8190
z, z_mu, z_log_sigma, z_eps = self.sampler(h_enc)
8291
self.kld = KLDivergence()(z_mu, z_log_sigma)
8392
x_tilde = self.generator(z, self.y_src)
84-
# if self.mode == 'vaegan':
85-
# x_tilde = ScaleGradient()(x_tilde)
8693
self.logpxz = self.reconstruct_error(x_tilde, self.x_src)
87-
loss = self.kld + expr.sum(self.logpxz)
94+
loss = 0.5*self.kld + expr.sum(self.logpxz)
8895

8996
if self.mode in ['gan', 'vaegan']:
9097
y = self.y_src
9198
if self.mode == 'gan':
9299
z = self.sampler.samples()
93100
x_tilde = self.generator(z, y)
94-
x_tilde = NegativeGradient()(x_tilde)
95101
gen_size = batch_size
96102
elif self.mode == 'vaegan':
97-
z = NegativeGradient()(z)
103+
z = ScaleGradient(0.0)(z)
98104
z = expr.Concatenate(axis=0)(z, z_eps)
99105
y = expr.Concatenate(axis=0)(y, self.y_src)
100106
x_tilde = self.generator_neg(z, y)
101-
x_tilde = NegativeGradient()(x_tilde)
102107
gen_size = batch_size*2
103108
x = expr.Concatenate(axis=0)(self.x_src, x_tilde)
104109
y = expr.Concatenate(axis=0)(y, self.y_src)

cond_vaegan_cifar.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@ def affine(n_out, gain):
1616
return expr.nnet.Affine(n_out=n_out, weights=dp.AutoFiller(gain))
1717

1818

def conv(n_filters, filter_size, stride=1, gain=1.0):
    """Create a convolution expression with a square filter and stride.

    Args:
        n_filters: Number of filters, passed through as ``n_filters``.
        filter_size: Side length; the filter shape is
            ``(filter_size, filter_size)``.
        stride: Stride used along both axes, i.e. ``(stride, stride)``.
        gain: Argument given to ``dp.AutoFiller`` for the weights.

    Returns:
        An ``expr.nnet.Convolution`` configured with ``border_mode='same'``.
    """
    square_filter = (filter_size, filter_size)
    square_stride = (stride, stride)
    return expr.nnet.Convolution(
        n_filters=n_filters,
        filter_shape=square_filter,
        strides=square_stride,
        border_mode='same',
        weights=dp.AutoFiller(gain),
    )
2425

2526

@@ -84,23 +85,22 @@ def model_expressions(img_shape):
8485
conv(n_channels, 3, gain=gain),
8586
])
8687
discriminator = cond_vaegan.ConditionalSequential([
87-
conv(32, 5, gain=gain),
88-
pool(),
88+
conv(32, 5, stride=2, gain=gain),
8989
expr.nnet.ReLU(),
9090
expr.nnet.SpatialDropout(0.2),
91-
conv(64, 5, gain=gain),
92-
pool(),
91+
conv(64, 5, stride=2, gain=gain),
9392
expr.nnet.ReLU(),
9493
expr.nnet.SpatialDropout(0.2),
9594
conv(96, 3, gain=gain),
96-
expr.nnet.ReLU(),
97-
expr.nnet.SpatialDropout(0.2),
9895
expr.Reshape((-1, 96*8*8)),
99-
expr.Concatenate(axis=1),
100-
affine(n_discriminator, gain),
101-
expr.nnet.ReLU(),
102-
expr.nnet.Dropout(0.5),
103-
affine(1, gain),
96+
# expr.nnet.ReLU(),
97+
# expr.nnet.SpatialDropout(0.2),
98+
# expr.Reshape((-1, 96*8*8)),
99+
# expr.Concatenate(axis=1),
100+
# affine(n_discriminator, gain),
101+
## expr.nnet.ReLU(),
102+
## expr.nnet.Dropout(0.5),
103+
## affine(1, gain),
104104
expr.nnet.Sigmoid(),
105105
])
106106
return encoder, sampler, generator, discriminator
@@ -112,7 +112,7 @@ def clip_range(imgs):
112112

113113
def run():
114114
mode = 'gan'
115-
experiment_name = mode
115+
experiment_name = mode + '_stride_local_discrimination'
116116
filename = 'savestates/cifar_cond_' + experiment_name + '.pickle'
117117
in_filename = filename
118118
in_filename = None
@@ -181,9 +181,12 @@ def plot():
181181

182182
# Train network
183183
runs = [
184-
(150, dp.RMSProp(learn_rate=0.1)),
185-
(150, dp.RMSProp(learn_rate=0.08)),
184+
# (10, dp.RMSProp(learn_rate=0.08)),
185+
# (25, dp.RMSProp(learn_rate=0.12)),
186+
# (100, dp.RMSProp(learn_rate=0.1)),
187+
(150, dp.RMSProp(learn_rate=0.075)),
186188
(150, dp.RMSProp(learn_rate=0.06)),
189+
(150, dp.RMSProp(learn_rate=0.05)),
187190
(150, dp.RMSProp(learn_rate=0.04)),
188191
(25, dp.RMSProp(learn_rate=0.01)),
189192
]

cond_vaegan_mnist.py

Lines changed: 85 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,19 @@ def model_expressions(img_shape):
2121
sigma = 0.001
2222
n_in = np.prod(img_shape)
2323
n_encoder = 1024
24-
n_hidden = 64
25-
n_generator = 1024
26-
n_discriminator = 1024
24+
n_hidden = 32
25+
n_generator = 2048
26+
n_discriminator = 2048
2727

2828
encoder = cond_vaegan.ConditionalSequential([
2929
expr.Concatenate(axis=1),
3030
affine(n_encoder, gain),
31+
expr.nnet.BatchNormalization(),
3132
expr.nnet.ReLU(),
33+
expr.nnet.Dropout(0.5),
34+
expr.Concatenate(axis=1),
3235
affine(n_encoder, gain),
36+
expr.nnet.BatchNormalization(),
3337
expr.nnet.ReLU(),
3438
])
3539
sampler = vaegan.NormalSampler(
@@ -46,18 +50,21 @@ def model_expressions(img_shape):
4650
affine(n_generator, gain),
4751
expr.nnet.BatchNormalization(),
4852
expr.nnet.ReLU(),
53+
expr.Concatenate(axis=1),
4954
affine(n_in, gain),
5055
expr.nnet.Sigmoid(),
5156
])
5257
discriminator = cond_vaegan.ConditionalSequential([
53-
expr.nnet.Dropout(0.5),
5458
expr.Concatenate(axis=1),
5559
affine(n_discriminator, gain),
5660
expr.nnet.ReLU(),
57-
expr.nnet.Dropout(0.5),
61+
expr.nnet.Dropout(0.25),
5862
expr.Concatenate(axis=1),
5963
affine(n_discriminator, gain),
64+
expr.nnet.BatchNormalization(),
6065
expr.nnet.ReLU(),
66+
expr.nnet.Dropout(0.25),
67+
expr.Concatenate(axis=1),
6168
affine(1, gain),
6269
expr.nnet.Sigmoid(),
6370

@@ -71,9 +78,10 @@ def to_b01c(imgs_flat, img_shape):
7178

7279

7380
def run():
74-
mode = 'gan'
75-
experiment_name = mode
76-
filename = 'savestates/mnist_cond_' + experiment_name + '.pickle'
81+
mode = 'vaegan'
82+
vae_grad_scale = 0.025
83+
experiment_name = mode + 'scale_%.5f' % vae_grad_scale
84+
filename = 'savestates/mnist_' + experiment_name + '.pickle'
7785
in_filename = filename
7886
in_filename = None
7987
print('experiment_name', experiment_name)
@@ -95,6 +103,7 @@ def run():
95103
x_train = np.reshape(x_train, (x_train.shape[0], -1))
96104
x_test = np.reshape(x_test, (x_test.shape[0], -1))
97105

106+
98107
# Setup network
99108
if in_filename is None:
100109
print('Creating new model')
@@ -111,11 +120,12 @@ def run():
111120
generator=generator,
112121
discriminator=discriminator,
113122
mode=mode,
114-
reconstruct_error=expr.nnet.BinaryCrossEntropy()
123+
reconstruct_error=expr.nnet.BinaryCrossEntropy(),
124+
vae_grad_scale=vae_grad_scale,
115125
)
116126

117127
# Prepare network inputs
118-
batch_size = 64
128+
batch_size = 128
119129
train_input = dp.SupervisedInput(x_train, y_train, batch_size=batch_size,
120130
epoch_size=250)
121131

@@ -136,6 +146,7 @@ def run():
136146

137147
def plot():
138148
model.phase = 'test'
149+
model.sampler.batch_size=100
139150
examples_z = model.embed(examples, examples_y)
140151
examples_recon = model.reconstruct(examples_z, examples_y)
141152
recon_video.append(img_tile(to_b01c(examples_recon, img_shape)))
@@ -144,12 +155,13 @@ def plot():
144155
model.setup(**train_input.shapes)
145156
model.phase = 'train'
146157

158+
147159
# Train network
148160
runs = [
149-
(50, dp.RMSProp(learn_rate=0.3)),
150-
(150, dp.RMSProp(learn_rate=0.1)),
151-
(5, dp.RMSProp(learn_rate=0.05)),
161+
(75, dp.RMSProp(learn_rate=0.075)),
162+
(25, dp.RMSProp(learn_rate=0.05)),
152163
(5, dp.RMSProp(learn_rate=0.01)),
164+
(5, dp.RMSProp(learn_rate=0.005)),
153165
]
154166
try:
155167
for n_epochs, learn_rule in runs:
@@ -167,12 +179,71 @@ def plot():
167179
expressions = encoder, sampler, generator, discriminator
168180
pickle.dump(expressions, f)
169181

182+
model.phase = 'test'
183+
batch_size = 128
184+
model.sampler.batch_size=128
185+
z = []
186+
i = 0
187+
z = model.embed(x_train, y_train)
188+
print(z.shape)
189+
z_mean = np.mean(z, axis=0)
190+
z_std = np.std(z, axis=0)
191+
z_cov = np.cov(z.T)
192+
print(np.mean(z_mean), np.std(z_mean))
193+
print(np.mean(z_std), np.std(z_std))
194+
print(z_mean.shape, z_std.shape, z_cov.shape)
195+
196+
197+
raw_input('\n\ngenerate latent space video?\n')
170198
print('Generating latent space video')
171199
walk_video = Video('plots/mnist_' + experiment_name + '_walk.mp4')
172-
for z in random_walk(samples_z, 500, step_std=0.15):
200+
for z in random_walk(samples_z, 500, n_dir_steps=10, mean=z_mean, std=z_cov):
173201
samples = model.reconstruct(z, samples_y)
174202
walk_video.append(img_tile(to_b01c(samples, img_shape)))
175203

176204

205+
206+
print('Generating AdversarialMNIST dataset')
207+
_, y_train, _, y_test = dataset.arrays(dp_dtypes=True)
208+
n = 0
209+
batch_size = 512
210+
advmnist_size = 1e6
211+
x_advmnist = np.empty((advmnist_size, 28*28))
212+
y_advmnist = np.empty((advmnist_size,))
213+
while n < advmnist_size:
214+
samples_z = np.random.multivariate_normal(mean=z_mean, cov=z_cov,
215+
size=batch_size)
216+
samples_z = samples_z.astype(dp.float_)
217+
start_idx = n % len(y_train)
218+
stop_idx = (n + batch_size) % len(y_train)
219+
if start_idx > stop_idx:
220+
samples_y = np.concatenate([y_train[start_idx:], y_train[:stop_idx]])
221+
else:
222+
samples_y = y_train[start_idx:stop_idx]
223+
y_advmnist[n:n+batch_size] = samples_y[:advmnist_size-n]
224+
samples_y = one_hot(samples_y, n_classes).astype(dp.float_)
225+
samples = model.reconstruct(samples_z, samples_y)
226+
x_advmnist[n:n+batch_size] = samples[:advmnist_size-n]
227+
n += batch_size
228+
229+
230+
x_train = x_advmnist
231+
y_train = y_advmnist
232+
import sklearn.neighbors
233+
clf = sklearn.neighbors.KNeighborsClassifier(n_neighbors=1, algorithm='brute', n_jobs=-1)
234+
clf.fit(x_train, y_train)
235+
print('KNN predict')
236+
step = 2500
237+
errors = []
238+
i = 0
239+
while i < len(x_test):
240+
print(i)
241+
errors.append(clf.predict(x_test[i:i+step]) != y_test[i:i+step])
242+
i += step
243+
error = np.mean(errors)
244+
print('Test error rate: %.4f' % error)
245+
246+
print('DONE ' + experiment_name)
247+
177248
if __name__ == '__main__':
178249
run()

util.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77

88
def img_tile(imgs):
    """Tile a batch of images with ``dp.misc.img_tile``.

    Arrays whose dtype is ``np.int_`` or ``np.uint8`` are tiled as they
    are; every other dtype is first passed through
    ``dp.misc.img_stretch``.
    """
    is_integer_image = imgs.dtype in [np.int_, np.uint8]
    tiles = imgs if is_integer_image else dp.misc.img_stretch(imgs)
    return dp.misc.img_tile(tiles)
1012

1113

1214
def plot_img(img, title, filename=None):
@@ -26,13 +28,19 @@ def one_hot(labels, n_classes):
2628
return onehot
2729

2830

def random_walk(start_pos, n_steps, n_dir_steps=10, mean=0.0, std=1.0):
    """Yield the positions of a random walk that drifts between random targets.

    Every ``n_dir_steps`` steps a new target point is drawn, and the walk
    then moves towards it in ``n_dir_steps`` equal increments, so the target
    is reached (up to rounding) just as the next one is drawn.

    Args:
        start_pos: Array holding the initial position(s). It is copied and
            never modified.
        n_steps: Number of positions to yield.
        n_dir_steps: Number of steps spent walking towards each target.
        mean: A scalar (Python or NumPy number) selects independent normal
            draws with this mean for every element. Anything else is
            treated as a mean vector for ``np.random.multivariate_normal``,
            with one draw per row of ``start_pos``.
        std: Standard deviation in the scalar case. In the vector case it
            is passed as the covariance matrix (the name is kept for
            backward compatibility).

    Yields:
        A new array with the shape of ``start_pos`` for every step.

    Raises:
        ValueError: If ``n_dir_steps`` is smaller than 1.
    """
    if n_dir_steps < 1:
        raise ValueError('n_dir_steps must be a positive integer')

    pos = np.copy(start_pos)
    for i in range(n_steps):
        if i % n_dir_steps == 0:
            # np.ndim() also recognises ints and NumPy scalars, which an
            # isinstance(mean, float) check would wrongly send down the
            # multivariate branch.
            if np.ndim(mean) == 0:
                next_point = np.random.normal(
                    loc=mean, scale=std, size=pos.shape
                )
            else:
                next_point = np.random.multivariate_normal(
                    mean=mean, cov=std, size=pos.shape[0]
                )
            step = (next_point - pos) / n_dir_steps
        pos += step
        # Hand out a copy so that positions collected by the caller are not
        # overwritten by later steps.
        yield pos.copy()

0 commit comments

Comments
 (0)