Commit 7e3f57e

Updated Variational AutoEncoder example for Keras 3 (keras-team#1836)
* vae keras 3 example updated
* seed generator added to rng layer
* generated files are added
1 parent 62f5911 commit 7e3f57e

File tree: 3 files changed, +52 −39 lines


examples/generative/ipynb/vae.ipynb

Lines changed: 24 additions & 21 deletions
```diff
@@ -10,7 +10,7 @@
 "\n",
 "**Author:** [fchollet](https://twitter.com/fchollet)<br>\n",
 "**Date created:** 2020/05/03<br>\n",
-"**Last modified:** 2023/11/22<br>\n",
+"**Last modified:** 2024/04/24<br>\n",
 "**Description:** Convolutional Variational AutoEncoder (VAE) trained on MNIST digits."
 ]
 },
@@ -25,7 +25,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -38,6 +38,7 @@
 "import numpy as np\n",
 "import tensorflow as tf\n",
 "import keras\n",
+"from keras import ops\n",
 "from keras import layers"
 ]
 },
@@ -52,7 +53,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -62,13 +63,16 @@
 "class Sampling(layers.Layer):\n",
 "    \"\"\"Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.\"\"\"\n",
 "\n",
+"    def __init__(self, **kwargs):\n",
+"        super().__init__(**kwargs)\n",
+"        self.seed_generator = keras.random.SeedGenerator(1337)\n",
+"\n",
 "    def call(self, inputs):\n",
 "        z_mean, z_log_var = inputs\n",
-"        batch = tf.shape(z_mean)[0]\n",
-"        dim = tf.shape(z_mean)[1]\n",
-"        epsilon = tf.random.normal(shape=(batch, dim))\n",
-"        return z_mean + tf.exp(0.5 * z_log_var) * epsilon\n",
-""
+"        batch = ops.shape(z_mean)[0]\n",
+"        dim = ops.shape(z_mean)[1]\n",
+"        epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)\n",
+"        return z_mean + ops.exp(0.5 * z_log_var) * epsilon\n"
 ]
 },
 {
```
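The heart of the migration is this `Sampling` hunk: backend-specific `tf.shape`/`tf.exp`/`tf.random.normal` calls become their `keras.ops` and `keras.random` equivalents, and the layer now owns a `keras.random.SeedGenerator` so sampling stays reproducible and works on backends with stateless RNG such as JAX. As a minimal standalone sketch of the updated layer (the backend pin and smoke-test values below are illustrative assumptions, not part of the commit):

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # assumption: any installed backend works

import keras
from keras import ops
from keras import layers


class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Stateful seed: reproducible stream, fresh noise on every call.
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = ops.shape(z_mean)[0]
        dim = ops.shape(z_mean)[1]
        # Reparameterization trick: z = mu + sigma * epsilon.
        epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)
        return z_mean + ops.exp(0.5 * z_log_var) * epsilon


# Illustrative smoke test: mu = 0 and log-variance = 0 give standard-normal samples.
z = Sampling()((ops.zeros((4, 2)), ops.zeros((4, 2))))
print(ops.shape(z))  # (4, 2)
```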
```diff
@@ -82,7 +86,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -113,7 +117,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -140,7 +144,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -170,14 +174,14 @@
 "        with tf.GradientTape() as tape:\n",
 "            z_mean, z_log_var, z = self.encoder(data)\n",
 "            reconstruction = self.decoder(z)\n",
-"            reconstruction_loss = tf.reduce_mean(\n",
-"                tf.reduce_sum(\n",
+"            reconstruction_loss = ops.mean(\n",
+"                ops.sum(\n",
 "                    keras.losses.binary_crossentropy(data, reconstruction),\n",
 "                    axis=(1, 2),\n",
 "                )\n",
 "            )\n",
-"            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))\n",
-"            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))\n",
+"            kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))\n",
+"            kl_loss = ops.mean(ops.sum(kl_loss, axis=1))\n",
 "        total_loss = reconstruction_loss + kl_loss\n",
 "        grads = tape.gradient(total_loss, self.trainable_weights)\n",
 "        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))\n",
```
```diff
@@ -188,8 +192,7 @@
 "            \"loss\": self.total_loss_tracker.result(),\n",
 "            \"reconstruction_loss\": self.reconstruction_loss_tracker.result(),\n",
 "            \"kl_loss\": self.kl_loss_tracker.result(),\n",
-"        }\n",
-""
+"        }\n"
 ]
 },
 {
@@ -203,7 +206,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -229,7 +232,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -286,7 +289,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 0,
+"execution_count": null,
 "metadata": {
 "colab_type": "code"
 },
@@ -340,4 +343,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 0
-}
+}
```

examples/generative/md/vae.md

Lines changed: 14 additions & 9 deletions
````diff
@@ -2,7 +2,7 @@
 
 **Author:** [fchollet](https://twitter.com/fchollet)<br>
 **Date created:** 2020/05/03<br>
-**Last modified:** 2023/11/22<br>
+**Last modified:** 2024/04/24<br>
 **Description:** Convolutional Variational AutoEncoder (VAE) trained on MNIST digits.
 
 
@@ -22,6 +22,7 @@ os.environ["KERAS_BACKEND"] = "tensorflow"
 import numpy as np
 import tensorflow as tf
 import keras
+from keras import ops
 from keras import layers
 ```
 
@@ -34,12 +35,16 @@ from keras import layers
 class Sampling(layers.Layer):
     """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""
 
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.seed_generator = keras.random.SeedGenerator(1337)
+
     def call(self, inputs):
         z_mean, z_log_var = inputs
-        batch = tf.shape(z_mean)[0]
-        dim = tf.shape(z_mean)[1]
-        epsilon = tf.random.normal(shape=(batch, dim))
-        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
+        batch = ops.shape(z_mean)[0]
+        dim = ops.shape(z_mean)[1]
+        epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)
+        return z_mean + ops.exp(0.5 * z_log_var) * epsilon
 
 ```
 
@@ -204,14 +209,14 @@ class VAE(keras.Model):
         with tf.GradientTape() as tape:
             z_mean, z_log_var, z = self.encoder(data)
             reconstruction = self.decoder(z)
-            reconstruction_loss = tf.reduce_mean(
-                tf.reduce_sum(
+            reconstruction_loss = ops.mean(
+                ops.sum(
                     keras.losses.binary_crossentropy(data, reconstruction),
                     axis=(1, 2),
                 )
             )
-            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
-            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
+            kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
+            kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
             total_loss = reconstruction_loss + kl_loss
             grads = tape.gradient(total_loss, self.trainable_weights)
             self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
````
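The markdown artifact carries the same change as the notebook, reflecting the commit note that generated files were regenerated. One way to convince yourself the renames are numerically neutral under the TensorFlow backend is a quick equivalence check; this sketch is an illustration, not part of the commit:

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import tensorflow as tf
from keras import ops

x = np.random.rand(4, 3).astype("float32")

# Each keras.ops call introduced by this commit matches its tf.* counterpart.
np.testing.assert_allclose(ops.mean(x), tf.reduce_mean(x).numpy())
np.testing.assert_allclose(ops.sum(x, axis=1), tf.reduce_sum(x, axis=1).numpy())
np.testing.assert_allclose(ops.square(x), tf.square(x).numpy())
np.testing.assert_allclose(ops.exp(x), tf.exp(x).numpy())
assert tuple(ops.shape(x)) == tuple(tf.shape(x).numpy())
```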

examples/generative/vae.py

Lines changed: 14 additions & 9 deletions
```diff
@@ -2,7 +2,7 @@
 Title: Variational AutoEncoder
 Author: [fchollet](https://twitter.com/fchollet)
 Date created: 2020/05/03
-Last modified: 2023/11/22
+Last modified: 2024/04/24
 Description: Convolutional Variational AutoEncoder (VAE) trained on MNIST digits.
 Accelerator: GPU
 """
@@ -18,6 +18,7 @@
 import numpy as np
 import tensorflow as tf
 import keras
+from keras import ops
 from keras import layers
 
 """
@@ -28,12 +29,16 @@
 class Sampling(layers.Layer):
     """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""
 
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.seed_generator = keras.random.SeedGenerator(1337)
+
     def call(self, inputs):
         z_mean, z_log_var = inputs
-        batch = tf.shape(z_mean)[0]
-        dim = tf.shape(z_mean)[1]
-        epsilon = tf.random.normal(shape=(batch, dim))
-        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
+        batch = ops.shape(z_mean)[0]
+        dim = ops.shape(z_mean)[1]
+        epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)
+        return z_mean + ops.exp(0.5 * z_log_var) * epsilon
 
 
 """
@@ -94,14 +99,14 @@ def train_step(self, data):
         with tf.GradientTape() as tape:
             z_mean, z_log_var, z = self.encoder(data)
             reconstruction = self.decoder(z)
-            reconstruction_loss = tf.reduce_mean(
-                tf.reduce_sum(
+            reconstruction_loss = ops.mean(
+                ops.sum(
                     keras.losses.binary_crossentropy(data, reconstruction),
                     axis=(1, 2),
                 )
             )
-            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
-            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
+            kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
+            kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
             total_loss = reconstruction_loss + kl_loss
             grads = tape.gradient(total_loss, self.trainable_weights)
             self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
```
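The commit message's "seed generator added to rng layer" is the behavioral piece of the change: in Keras 3, a random op given a plain integer seed is stateless and repeats the same draw on every call, which would freeze the VAE's sampling noise, whereas a `SeedGenerator` carries updatable state so each call draws fresh noise from a reproducible stream. A small sketch of the distinction (illustrative, with the backend pinned as in the example):

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import keras

# Plain int seed: stateless op, identical output on every call.
a = keras.random.normal((2,), seed=1337)
b = keras.random.normal((2,), seed=1337)
np.testing.assert_allclose(a, b)  # same noise twice: wrong for a sampling layer

# SeedGenerator: state advances per call, stream stays reproducible.
gen = keras.random.SeedGenerator(1337)
c = keras.random.normal((2,), seed=gen)
d = keras.random.normal((2,), seed=gen)
assert not np.allclose(c, d)  # fresh noise each call
```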
