10
10
" \n " ,
11
11
" **Author:** [fchollet](https://twitter.com/fchollet)<br>\n " ,
12
12
" **Date created:** 2020/05/03<br>\n " ,
13
- " **Last modified:** 2023/11/22 <br>\n " ,
13
+ " **Last modified:** 2024/04/24 <br>\n " ,
14
14
" **Description:** Convolutional Variational AutoEncoder (VAE) trained on MNIST digits."
15
15
]
16
16
},
25
25
},
26
26
{
27
27
"cell_type" : " code" ,
28
- "execution_count" : 0 ,
28
+ "execution_count" : null ,
29
29
"metadata" : {
30
30
"colab_type" : " code"
31
31
},
38
38
" import numpy as np\n " ,
39
39
" import tensorflow as tf\n " ,
40
40
" import keras\n " ,
41
+ " from keras import ops\n " ,
41
42
" from keras import layers"
42
43
]
43
44
},
52
53
},
53
54
{
54
55
"cell_type" : " code" ,
55
- "execution_count" : 0 ,
56
+ "execution_count" : null ,
56
57
"metadata" : {
57
58
"colab_type" : " code"
58
59
},
62
63
" class Sampling(layers.Layer):\n " ,
63
64
" \"\"\" Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.\"\"\"\n " ,
64
65
" \n " ,
66
+ " def __init__(self, **kwargs):\n " ,
67
+ " super().__init__(**kwargs)\n " ,
68
+ " self.seed_generator = keras.random.SeedGenerator(1337)\n " ,
69
+ " \n " ,
65
70
" def call(self, inputs):\n " ,
66
71
" z_mean, z_log_var = inputs\n " ,
67
- " batch = tf.shape(z_mean)[0]\n " ,
68
- " dim = tf.shape(z_mean)[1]\n " ,
69
- " epsilon = tf.random.normal(shape=(batch, dim))\n " ,
70
- " return z_mean + tf.exp(0.5 * z_log_var) * epsilon\n " ,
71
- " "
72
+ " batch = ops.shape(z_mean)[0]\n " ,
73
+ " dim = ops.shape(z_mean)[1]\n " ,
74
+ " epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)\n " ,
75
+ " return z_mean + ops.exp(0.5 * z_log_var) * epsilon\n "
72
76
]
73
77
},
74
78
{
82
86
},
83
87
{
84
88
"cell_type" : " code" ,
85
- "execution_count" : 0 ,
89
+ "execution_count" : null ,
86
90
"metadata" : {
87
91
"colab_type" : " code"
88
92
},
113
117
},
114
118
{
115
119
"cell_type" : " code" ,
116
- "execution_count" : 0 ,
120
+ "execution_count" : null ,
117
121
"metadata" : {
118
122
"colab_type" : " code"
119
123
},
140
144
},
141
145
{
142
146
"cell_type" : " code" ,
143
- "execution_count" : 0 ,
147
+ "execution_count" : null ,
144
148
"metadata" : {
145
149
"colab_type" : " code"
146
150
},
170
174
" with tf.GradientTape() as tape:\n " ,
171
175
" z_mean, z_log_var, z = self.encoder(data)\n " ,
172
176
" reconstruction = self.decoder(z)\n " ,
173
- " reconstruction_loss = tf.reduce_mean (\n " ,
174
- " tf.reduce_sum (\n " ,
177
+ " reconstruction_loss = ops.mean (\n " ,
178
+ " ops.sum (\n " ,
175
179
" keras.losses.binary_crossentropy(data, reconstruction),\n " ,
176
180
" axis=(1, 2),\n " ,
177
181
" )\n " ,
178
182
" )\n " ,
179
- " kl_loss = -0.5 * (1 + z_log_var - tf .square(z_mean) - tf .exp(z_log_var))\n " ,
180
- " kl_loss = tf.reduce_mean(tf.reduce_sum (kl_loss, axis=1))\n " ,
183
+ " kl_loss = -0.5 * (1 + z_log_var - ops .square(z_mean) - ops .exp(z_log_var))\n " ,
184
+ " kl_loss = ops.mean(ops.sum (kl_loss, axis=1))\n " ,
181
185
" total_loss = reconstruction_loss + kl_loss\n " ,
182
186
" grads = tape.gradient(total_loss, self.trainable_weights)\n " ,
183
187
" self.optimizer.apply_gradients(zip(grads, self.trainable_weights))\n " ,
188
192
" \" loss\" : self.total_loss_tracker.result(),\n " ,
189
193
" \" reconstruction_loss\" : self.reconstruction_loss_tracker.result(),\n " ,
190
194
" \" kl_loss\" : self.kl_loss_tracker.result(),\n " ,
191
- " }\n " ,
192
- " "
195
+ " }\n "
193
196
]
194
197
},
195
198
{
203
206
},
204
207
{
205
208
"cell_type" : " code" ,
206
- "execution_count" : 0 ,
209
+ "execution_count" : null ,
207
210
"metadata" : {
208
211
"colab_type" : " code"
209
212
},
229
232
},
230
233
{
231
234
"cell_type" : " code" ,
232
- "execution_count" : 0 ,
235
+ "execution_count" : null ,
233
236
"metadata" : {
234
237
"colab_type" : " code"
235
238
},
286
289
},
287
290
{
288
291
"cell_type" : " code" ,
289
- "execution_count" : 0 ,
292
+ "execution_count" : null ,
290
293
"metadata" : {
291
294
"colab_type" : " code"
292
295
},
340
343
},
341
344
"nbformat" : 4 ,
342
345
"nbformat_minor" : 0
343
- }
346
+ }
0 commit comments