Skip to content

Conversation

@sitammeur
Copy link
Contributor

This PR updates the Variational AutoEncoder Keras 3.0 example [TF Only Backend]. All TF ops are replaced with corresponding Keras ops.

For example, here is the notebook link provided:
https://colab.research.google.com/drive/15VMQIEUK8jhqI8nYyFSfWDh9fsMMs116?usp=sharing

cc: @fchollet

The following describes the Git difference for the changed files:

Changes:
diff --git a/examples/generative/vae.py b/examples/generative/vae.py
index d1d195c9..3396f4f7 100644
--- a/examples/generative/vae.py
+++ b/examples/generative/vae.py
@@ -18,6 +18,7 @@ os.environ["KERAS_BACKEND"] = "tensorflow"
 import numpy as np
 import tensorflow as tf
 import keras
+from keras import ops
 from keras import layers
 
 """
@@ -30,10 +31,10 @@ class Sampling(layers.Layer):
 
     def call(self, inputs):
         z_mean, z_log_var = inputs
-        batch = tf.shape(z_mean)[0]
-        dim = tf.shape(z_mean)[1]
-        epsilon = tf.random.normal(shape=(batch, dim))
-        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
+        batch = ops.shape(z_mean)[0]
+        dim = ops.shape(z_mean)[1]
+        epsilon = keras.random.normal(shape=(batch, dim))
+        return z_mean + ops.exp(0.5 * z_log_var) * epsilon
 
 
 """
@@ -94,14 +95,14 @@ class VAE(keras.Model):
         with tf.GradientTape() as tape:
             z_mean, z_log_var, z = self.encoder(data)
             reconstruction = self.decoder(z)
-            reconstruction_loss = tf.reduce_mean(
-                tf.reduce_sum(
+            reconstruction_loss = ops.mean(
+                ops.sum(
                     keras.losses.binary_crossentropy(data, reconstruction),
                     axis=(1, 2),
                 )
             )
-            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
-            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
+            kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
+            kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
             total_loss = reconstruction_loss + kl_loss
         grads = tape.gradient(total_loss, self.trainable_weights)
         self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
(END)

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

return z_mean + tf.exp(0.5 * z_log_var) * epsilon
batch = ops.shape(z_mean)[0]
dim = ops.shape(z_mean)[1]
epsilon = keras.random.normal(shape=(batch, dim))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be seeded via seed=self.seed_generator (to be created in the constructor, self.seed_generator = keras.random.SeedGenerator())

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I apologize! I overlooked the RNG layers' statelessness. Updated in the most recent commit.

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you -- please add the generated files.

@sitammeur
Copy link
Contributor Author

LGTM, thank you -- please add the generated files.

Absolutely! The generated files have been added.

@fchollet fchollet merged commit 190dbe5 into keras-team:master Apr 24, 2024
sitammeur added a commit to sitammeur/keras-io that referenced this pull request May 30, 2024
* vae keras 3 example updated

* seed generator added to rng layer

* generated files are added
@sitammeur sitammeur deleted the vae branch May 30, 2024 15:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants