From 747d64bd5db9964b95fe86f1e7f1e581e5a02314 Mon Sep 17 00:00:00 2001
From: Ossan Tamago <52318961+ossan-tamago@users.noreply.github.com>
Date: Wed, 6 Jul 2022 06:32:05 +0900
Subject: [PATCH] fix cyclegan save error (#925)

* fix cyclegan save error

* formatted

* comment

* remove

* ipynb
---
 examples/generative/cyclegan.py          | 13 ++++-
 examples/generative/ipynb/cyclegan.ipynb | 63 ++++++++++++------------
 2 files changed, 42 insertions(+), 34 deletions(-)

diff --git a/examples/generative/cyclegan.py b/examples/generative/cyclegan.py
index 97926ca6d6..5957230166 100644
--- a/examples/generative/cyclegan.py
+++ b/examples/generative/cyclegan.py
@@ -24,14 +24,13 @@
 
 ## Setup
 """
+
 import os
 import numpy as np
 import matplotlib.pyplot as plt
-
 import tensorflow as tf
 from tensorflow import keras
 from tensorflow.keras import layers
-
 import tensorflow_addons as tfa
 import tensorflow_datasets as tfds
 
@@ -409,6 +408,14 @@ def __init__(
         self.lambda_cycle = lambda_cycle
         self.lambda_identity = lambda_identity
 
+    def call(self, inputs):
+        return (
+            self.disc_X(inputs),
+            self.disc_Y(inputs),
+            self.gen_G(inputs),
+            self.gen_F(inputs),
+        )
+
     def compile(
         self,
         gen_G_optimizer,
@@ -571,6 +578,8 @@ def on_epoch_end(self, epoch, logs=None):
 adv_loss_fn = keras.losses.MeanSquaredError()
 
 # Define the loss function for the generators
+
+
 def generator_loss_fn(fake):
     fake_loss = adv_loss_fn(tf.ones_like(fake), fake)
     return fake_loss
diff --git a/examples/generative/ipynb/cyclegan.ipynb b/examples/generative/ipynb/cyclegan.ipynb
index 50f0a36f14..a8cc7a157c 100644
--- a/examples/generative/ipynb/cyclegan.ipynb
+++ b/examples/generative/ipynb/cyclegan.ipynb
@@ -44,7 +44,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "colab_type": "code"
    },
@@ -62,8 +62,7 @@
     "import tensorflow_datasets as tfds\n",
     "\n",
     "tfds.disable_progress_bar()\n",
-    "autotune = tf.data.AUTOTUNE\n",
-    ""
+    "autotune = tf.data.AUTOTUNE\n"
    ]
   },
   {
@@ -81,7 +80,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "colab_type": "code"
    },
@@ -127,8 +126,7 @@
     "    # Only resizing and normalization for the test images.\n",
     "    img = tf.image.resize(img, [input_img_size[0], input_img_size[1]])\n",
     "    img = normalize_img(img)\n",
-    "    return img\n",
-    ""
+    "    return img\n"
    ]
   },
   {
@@ -142,7 +140,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "colab_type": "code"
    },
@@ -175,8 +173,7 @@
     "    .cache()\n",
     "    .shuffle(buffer_size)\n",
     "    .batch(batch_size)\n",
-    ")\n",
-    ""
+    ")\n"
    ]
   },
   {
@@ -190,7 +187,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "colab_type": "code"
    },
@@ -203,8 +200,7 @@
     "    zebra = (((samples[1][0] * 127.5) + 127.5).numpy()).astype(np.uint8)\n",
     "    ax[i, 0].imshow(horse)\n",
     "    ax[i, 1].imshow(zebra)\n",
-    "plt.show()\n",
-    ""
+    "plt.show()\n"
    ]
   },
   {
@@ -218,7 +214,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "colab_type": "code"
    },
@@ -337,8 +333,7 @@
     "    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)\n",
     "    if activation:\n",
     "        x = activation(x)\n",
-    "    return x\n",
-    ""
+    "    return x\n"
    ]
   },
   {
@@ -375,7 +370,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "colab_type": "code"
    },
@@ -418,8 +413,7 @@
     "    x = layers.Activation(\"tanh\")(x)\n",
     "\n",
     "    model = keras.models.Model(img_input, x, name=name)\n",
-    "    return model\n",
-    ""
+    "    return model\n"
    ]
   },
   {
@@ -436,7 +430,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "colab_type": "code"
    },
@@ -490,8 +484,7 @@
     "\n",
     "# Get the discriminators\n",
     "disc_X = get_discriminator(name=\"discriminator_X\")\n",
-    "disc_Y = get_discriminator(name=\"discriminator_Y\")\n",
-    ""
+    "disc_Y = get_discriminator(name=\"discriminator_Y\")\n"
    ]
   },
   {
@@ -508,7 +501,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "colab_type": "code"
    },
@@ -533,6 +526,14 @@
     "        self.lambda_cycle = lambda_cycle\n",
     "        self.lambda_identity = lambda_identity\n",
     "\n",
+    "    def call(self, inputs):\n",
+    "        return (\n",
+    "            self.disc_X(inputs),\n",
+    "            self.disc_Y(inputs),\n",
+    "            self.gen_G(inputs),\n",
+    "            self.gen_F(inputs),\n",
+    "        )\n",
+    "\n",
     "    def compile(\n",
     "        self,\n",
     "        gen_G_optimizer,\n",
@@ -650,8 +651,7 @@
     "            \"F_loss\": total_loss_F,\n",
     "            \"D_X_loss\": disc_X_loss,\n",
     "            \"D_Y_loss\": disc_Y_loss,\n",
-    "        }\n",
-    ""
+    "        }\n"
    ]
   },
   {
@@ -665,7 +665,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "colab_type": "code"
    },
@@ -697,8 +697,7 @@
     "                \"generated_img_{i}_{epoch}.png\".format(i=i, epoch=epoch + 1)\n",
     "            )\n",
     "        plt.show()\n",
-    "        plt.close()\n",
-    ""
+    "        plt.close()\n"
    ]
   },
   {
@@ -712,7 +711,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "colab_type": "code"
    },
@@ -773,13 +772,13 @@
    "source": [
     "Test the performance of the model.\n",
     "\n",
-    "You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/CycleGAN)",
+    "You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/CycleGAN)\n",
     "and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/CycleGAN)."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "colab_type": "code"
    },
@@ -793,7 +792,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "colab_type": "code"
    },
@@ -805,7 +804,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 0,
+   "execution_count": null,
    "metadata": {
     "colab_type": "code"
    },