Skip to content

Commit

Permalink
fix cyclegan save error (keras-team#925)
Browse files Browse the repository at this point in the history
* fix cyclegan save error

* formatted

* comment

* remove

* ipynb
  • Loading branch information
k2ok3i authored Jul 5, 2022
1 parent f10a038 commit 747d64b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 34 deletions.
13 changes: 11 additions & 2 deletions examples/generative/cyclegan.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@
## Setup
"""


import os
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import tensorflow_addons as tfa
import tensorflow_datasets as tfds

Expand Down Expand Up @@ -409,6 +408,14 @@ def __init__(
self.lambda_cycle = lambda_cycle
self.lambda_identity = lambda_identity

def call(self, inputs):
return (
self.disc_X(inputs),
self.disc_Y(inputs),
self.gen_G(inputs),
self.gen_F(inputs),
)

def compile(
self,
gen_G_optimizer,
Expand Down Expand Up @@ -571,6 +578,8 @@ def on_epoch_end(self, epoch, logs=None):
adv_loss_fn = keras.losses.MeanSquaredError()

# Define the loss function for the generators


def generator_loss_fn(fake):
fake_loss = adv_loss_fn(tf.ones_like(fake), fake)
return fake_loss
Expand Down
63 changes: 31 additions & 32 deletions examples/generative/ipynb/cyclegan.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand All @@ -62,8 +62,7 @@
"import tensorflow_datasets as tfds\n",
"\n",
"tfds.disable_progress_bar()\n",
"autotune = tf.data.AUTOTUNE\n",
""
"autotune = tf.data.AUTOTUNE\n"
]
},
{
Expand All @@ -81,7 +80,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -127,8 +126,7 @@
" # Only resizing and normalization for the test images.\n",
" img = tf.image.resize(img, [input_img_size[0], input_img_size[1]])\n",
" img = normalize_img(img)\n",
" return img\n",
""
" return img\n"
]
},
{
Expand All @@ -142,7 +140,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -175,8 +173,7 @@
" .cache()\n",
" .shuffle(buffer_size)\n",
" .batch(batch_size)\n",
")\n",
""
")\n"
]
},
{
Expand All @@ -190,7 +187,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand All @@ -203,8 +200,7 @@
" zebra = (((samples[1][0] * 127.5) + 127.5).numpy()).astype(np.uint8)\n",
" ax[i, 0].imshow(horse)\n",
" ax[i, 1].imshow(zebra)\n",
"plt.show()\n",
""
"plt.show()\n"
]
},
{
Expand All @@ -218,7 +214,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -337,8 +333,7 @@
" x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)\n",
" if activation:\n",
" x = activation(x)\n",
" return x\n",
""
" return x\n"
]
},
{
Expand Down Expand Up @@ -375,7 +370,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -418,8 +413,7 @@
" x = layers.Activation(\"tanh\")(x)\n",
"\n",
" model = keras.models.Model(img_input, x, name=name)\n",
" return model\n",
""
" return model\n"
]
},
{
Expand All @@ -436,7 +430,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -490,8 +484,7 @@
"\n",
"# Get the discriminators\n",
"disc_X = get_discriminator(name=\"discriminator_X\")\n",
"disc_Y = get_discriminator(name=\"discriminator_Y\")\n",
""
"disc_Y = get_discriminator(name=\"discriminator_Y\")\n"
]
},
{
Expand All @@ -508,7 +501,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand All @@ -533,6 +526,14 @@
" self.lambda_cycle = lambda_cycle\n",
" self.lambda_identity = lambda_identity\n",
"\n",
" def call(self, inputs):\n",
" return (\n",
" self.disc_X(inputs),\n",
" self.disc_Y(inputs),\n",
" self.gen_G(inputs),\n",
" self.gen_F(inputs),\n",
" )\n",
"\n",
" def compile(\n",
" self,\n",
" gen_G_optimizer,\n",
Expand Down Expand Up @@ -650,8 +651,7 @@
" \"F_loss\": total_loss_F,\n",
" \"D_X_loss\": disc_X_loss,\n",
" \"D_Y_loss\": disc_Y_loss,\n",
" }\n",
""
" }\n"
]
},
{
Expand All @@ -665,7 +665,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -697,8 +697,7 @@
" \"generated_img_{i}_{epoch}.png\".format(i=i, epoch=epoch + 1)\n",
" )\n",
" plt.show()\n",
" plt.close()\n",
""
" plt.close()\n"
]
},
{
Expand All @@ -712,7 +711,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -773,13 +772,13 @@
"source": [
"Test the performance of the model.\n",
"\n",
"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/CycleGAN)",
"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/CycleGAN)\n",
"and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/CycleGAN)."
]
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand All @@ -793,7 +792,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand All @@ -805,7 +804,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab_type": "code"
},
Expand Down

0 comments on commit 747d64b

Please sign in to comment.