Skip to content

Commit

Permalink
update latent walk example (#1668)
Browse files Browse the repository at this point in the history
* update latent walk example

* update plot grid

---------

Co-authored-by: Divyashree Sreepathihalli <divyashreepathihalli>
  • Loading branch information
divyashreepathihalli authored Dec 8, 2023
1 parent 6d8b0dc commit 36e6e00
Show file tree
Hide file tree
Showing 5 changed files with 347 additions and 328 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
119 changes: 60 additions & 59 deletions examples/generative/ipynb/random_walks_with_stable_diffusion.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -74,25 +74,27 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"!pip install tensorflow keras_cv --upgrade --quiet"
"!pip install keras-cv --upgrade --quiet"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import keras_cv\n",
"from tensorflow import keras\n",
"import keras\n",
"import matplotlib.pyplot as plt\n",
"import tensorflow as tf\n",
"from keras import ops\n",
"import numpy as np\n",
"import math\n",
"from PIL import Image\n",
Expand Down Expand Up @@ -125,7 +127,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand All @@ -135,10 +137,10 @@
"prompt_2 = \"A still life DSLR photo of a bowl of fruit\"\n",
"interpolation_steps = 5\n",
"\n",
"encoding_1 = tf.squeeze(model.encode_text(prompt_1))\n",
"encoding_2 = tf.squeeze(model.encode_text(prompt_2))\n",
"encoding_1 = ops.squeeze(model.encode_text(prompt_1))\n",
"encoding_2 = ops.squeeze(model.encode_text(prompt_2))\n",
"\n",
"interpolated_encodings = tf.linspace(encoding_1, encoding_2, interpolation_steps)\n",
"interpolated_encodings = ops.linspace(encoding_1, encoding_2, interpolation_steps)\n",
"\n",
"# Show the size of the latent manifold\n",
"print(f\"Encoding shape: {encoding_1.shape}\")"
Expand All @@ -157,14 +159,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"seed = 12345\n",
"noise = tf.random.normal((512 // 8, 512 // 8, 4), seed=seed)\n",
"noise = keras.random.normal((512 // 8, 512 // 8, 4), seed=seed)\n",
"\n",
"images = model.generate_image(\n",
" interpolated_encodings,\n",
Expand Down Expand Up @@ -196,7 +198,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -245,7 +247,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand All @@ -255,8 +257,8 @@
"batch_size = 3\n",
"batches = interpolation_steps // batch_size\n",
"\n",
"interpolated_encodings = tf.linspace(encoding_1, encoding_2, interpolation_steps)\n",
"batched_encodings = tf.split(interpolated_encodings, batches)\n",
"interpolated_encodings = ops.linspace(encoding_1, encoding_2, interpolation_steps)\n",
"batched_encodings = ops.split(interpolated_encodings, batches)\n",
"\n",
"images = []\n",
"for batch in range(batches):\n",
Expand Down Expand Up @@ -290,7 +292,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand All @@ -305,20 +307,20 @@
"batch_size = 3\n",
"batches = (interpolation_steps**2) // batch_size\n",
"\n",
"encoding_1 = tf.squeeze(model.encode_text(prompt_1))\n",
"encoding_2 = tf.squeeze(model.encode_text(prompt_2))\n",
"encoding_3 = tf.squeeze(model.encode_text(prompt_3))\n",
"encoding_4 = tf.squeeze(model.encode_text(prompt_4))\n",
"encoding_1 = ops.squeeze(model.encode_text(prompt_1))\n",
"encoding_2 = ops.squeeze(model.encode_text(prompt_2))\n",
"encoding_3 = ops.squeeze(model.encode_text(prompt_3))\n",
"encoding_4 = ops.squeeze(model.encode_text(prompt_4))\n",
"\n",
"interpolated_encodings = tf.linspace(\n",
" tf.linspace(encoding_1, encoding_2, interpolation_steps),\n",
" tf.linspace(encoding_3, encoding_4, interpolation_steps),\n",
"interpolated_encodings = ops.linspace(\n",
" ops.linspace(encoding_1, encoding_2, interpolation_steps),\n",
" ops.linspace(encoding_3, encoding_4, interpolation_steps),\n",
" interpolation_steps,\n",
")\n",
"interpolated_encodings = tf.reshape(\n",
"interpolated_encodings = ops.reshape(\n",
" interpolated_encodings, (interpolation_steps**2, 77, 768)\n",
")\n",
"batched_encodings = tf.split(interpolated_encodings, batches)\n",
"batched_encodings = ops.split(interpolated_encodings, batches)\n",
"\n",
"images = []\n",
"for batch in range(batches):\n",
Expand All @@ -337,19 +339,23 @@
" grid_size,\n",
" scale=2,\n",
"):\n",
" fig = plt.figure(figsize=(grid_size * scale, grid_size * scale))\n",
" fig, axs = plt.subplots(\n",
" grid_size, grid_size, figsize=(grid_size * scale, grid_size * scale)\n",
" )\n",
" fig.tight_layout()\n",
" plt.subplots_adjust(wspace=0, hspace=0)\n",
" plt.margins(x=0, y=0)\n",
" plt.axis(\"off\")\n",
" for ax in axs.flat:\n",
" ax.axis(\"off\")\n",
" images = images.astype(int)\n",
" for row in range(grid_size):\n",
" for col in range(grid_size):\n",
" index = row * grid_size + col\n",
" plt.subplot(grid_size, grid_size, index + 1)\n",
" plt.imshow(images[index].astype(\"uint8\"))\n",
" plt.axis(\"off\")\n",
" plt.margins(x=0, y=0)\n",
" for i in range(min(grid_size * grid_size, len(images))):\n",
" ax = axs.flat[i]\n",
" ax.imshow(images[i].astype(\"uint8\"))\n",
" ax.axis(\"off\")\n",
" for i in range(len(images), grid_size * grid_size):\n",
" axs.flat[i].axis(\"off\")\n",
" axs.flat[i].remove()\n",
" plt.savefig(\n",
" fname=path,\n",
" pad_inches=0,\n",
Expand All @@ -375,7 +381,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -405,7 +411,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand All @@ -416,18 +422,18 @@
"batches = walk_steps // batch_size\n",
"step_size = 0.005\n",
"\n",
"encoding = tf.squeeze(\n",
"encoding = ops.squeeze(\n",
" model.encode_text(\"The Eiffel Tower in the style of starry night\")\n",
")\n",
"# Note that (77, 768) is the shape of the text encoding.\n",
"delta = tf.ones_like(encoding) * step_size\n",
"delta = ops.ones_like(encoding) * step_size\n",
"\n",
"walked_encodings = []\n",
"for step_index in range(walk_steps):\n",
" walked_encodings.append(encoding)\n",
" encoding += delta\n",
"walked_encodings = tf.stack(walked_encodings)\n",
"batched_encodings = tf.split(walked_encodings, batches)\n",
"walked_encodings = ops.stack(walked_encodings)\n",
"batched_encodings = ops.split(walked_encodings, batches)\n",
"\n",
"images = []\n",
"for batch in range(batches):\n",
Expand Down Expand Up @@ -464,35 +470,35 @@
"that the diffusion model can produce from that prompt. We do this by controlling\n",
"the noise that is used to seed the diffusion process.\n",
"\n",
"We create two noise components, `x` and `y`, and do a walk from 0 to 2π, summing\n",
"We create two noise components, `x` and `y`, and do a walk from 0 to 2\u03c0, summing\n",
"the cosine of our `x` component and the sine of our `y` component to produce noise.\n",
"Using this approach, the end of our walk arrives at the same noise inputs where\n",
"we began our walk, so we get a \"loopable\" result!"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"prompt = \"An oil paintings of cows in a field next to a windmill in Holland\"\n",
"encoding = tf.squeeze(model.encode_text(prompt))\n",
"encoding = ops.squeeze(model.encode_text(prompt))\n",
"walk_steps = 150\n",
"batch_size = 3\n",
"batches = walk_steps // batch_size\n",
"\n",
"walk_noise_x = tf.random.normal(noise.shape, dtype=tf.float64)\n",
"walk_noise_y = tf.random.normal(noise.shape, dtype=tf.float64)\n",
"walk_noise_x = keras.random.normal(noise.shape, dtype=\"float64\")\n",
"walk_noise_y = keras.random.normal(noise.shape, dtype=\"float64\")\n",
"\n",
"walk_scale_x = tf.cos(tf.linspace(0, 2, walk_steps) * math.pi)\n",
"walk_scale_y = tf.sin(tf.linspace(0, 2, walk_steps) * math.pi)\n",
"noise_x = tf.tensordot(walk_scale_x, walk_noise_x, axes=0)\n",
"noise_y = tf.tensordot(walk_scale_y, walk_noise_y, axes=0)\n",
"noise = tf.add(noise_x, noise_y)\n",
"batched_noise = tf.split(noise, batches)\n",
"walk_scale_x = ops.cos(ops.linspace(0, 2, walk_steps) * math.pi)\n",
"walk_scale_y = ops.sin(ops.linspace(0, 2, walk_steps) * math.pi)\n",
"noise_x = ops.tensordot(walk_scale_x, walk_noise_x, axes=0)\n",
"noise_y = ops.tensordot(walk_scale_y, walk_noise_y, axes=0)\n",
"noise = ops.add(noise_x, noise_y)\n",
"batched_noise = ops.split(noise, batches)\n",
"\n",
"images = []\n",
"for batch in range(batches):\n",
Expand Down Expand Up @@ -539,7 +545,7 @@
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3.10.7 64-bit",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
Expand All @@ -553,14 +559,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.7"
},
"vscode": {
"interpreter": {
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
}
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
}
Loading

0 comments on commit 36e6e00

Please sign in to comment.