Skip to content

Commit

Permalink
Migrating Fully Convolutional Network example to Keras 3 (keras-team#…
Browse files Browse the repository at this point in the history
…1691)

* Migrated fully_convolutional_network example to Keras 3

* Initialised seed to integer values
  • Loading branch information
aditya02shah authored and SuryanarayanaY committed Jan 19, 2024
1 parent 82e512a commit 9428127
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 39 deletions.
26 changes: 13 additions & 13 deletions examples/vision/fully_convolutional_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Title: Image Segmentation using Composable Fully-Convolutional Networks
Author: [Suvaditya Mukherjee](https://twitter.com/halcyonrayes)
Date created: 2023/06/16
Last modified: 2023/06/16
Last modified: 2023/12/25
Description: Using the Fully-Convolutional Network for Image Segmentation.
Accelerator: GPU
"""
Expand Down Expand Up @@ -48,15 +48,15 @@
## Setup Imports
"""

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
from keras import ops
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import numpy as np

keras.utils.set_random_seed(27)
tf.random.set_seed(27)

AUTOTUNE = tf.data.AUTOTUNE

"""
Expand Down Expand Up @@ -147,10 +147,10 @@ def unpack_resize_data(section):
# for Matplotlib visualization.

images, masks = next(iter(test_ds))
random_idx = tf.random.uniform([], minval=0, maxval=BATCH_SIZE, dtype=tf.int32)
random_idx = keras.random.uniform([], minval=0, maxval=BATCH_SIZE, seed=10)

test_image = images[random_idx].numpy().astype("float")
test_mask = masks[random_idx].numpy().astype("float")
test_image = images[int(random_idx)].numpy().astype("float")
test_mask = masks[int(random_idx)].numpy().astype("float")

# Overlay segmentation mask on top of image.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
Expand Down Expand Up @@ -259,7 +259,7 @@ def preprocess_data(image, segmentation_mask):
activation="relu",
padding="same",
use_bias=False,
kernel_initializer=tf.constant_initializer(1.0),
kernel_initializer=keras.initializers.Constant(1.0),
)
dense_convs.append(dense_conv)
dropout_layer = keras.layers.Dropout(0.5)
Expand Down Expand Up @@ -556,13 +556,13 @@ def preprocess_data(image, segmentation_mask):
"""

images, masks = next(iter(test_ds))
random_idx = tf.random.uniform([], minval=0, maxval=BATCH_SIZE, dtype=tf.int32)
random_idx = keras.random.uniform([], minval=0, maxval=BATCH_SIZE,seed=10)

# Get random test image and mask
test_image = images[random_idx].numpy().astype("float")
test_mask = masks[random_idx].numpy().astype("float")
test_image = images[int(random_idx)].numpy().astype("float")
test_mask = masks[int(random_idx)].numpy().astype("float")

pred_image = tf.expand_dims(test_image, axis=0)
pred_image = ops.expand_dims(test_image, axis=0)
pred_image = keras.applications.vgg19.preprocess_input(pred_image)

# Perform inference on FCN-32S
Expand Down
26 changes: 13 additions & 13 deletions examples/vision/ipynb/fully_convolutional_network.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"\n",
"**Author:** [Suvaditya Mukherjee](https://twitter.com/halcyonrayes)<br>\n",
"**Date created:** 2023/06/16<br>\n",
"**Last modified:** 2023/06/16<br>\n",
"**Last modified:** 2023/12/25<br>\n",
"**Description:** Using the Fully-Convolutional Network for Image Segmentation."
]
},
Expand Down Expand Up @@ -75,15 +75,15 @@
},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
"import keras\n",
"from keras import ops\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"import matplotlib.pyplot as plt\n",
"import tensorflow_datasets as tfds\n",
"import numpy as np\n",
"\n",
"keras.utils.set_random_seed(27)\n",
"tf.random.set_seed(27)\n",
"\n",
"AUTOTUNE = tf.data.AUTOTUNE"
]
},
Expand Down Expand Up @@ -235,10 +235,10 @@
"# for Matplotlib visualization.\n",
"\n",
"images, masks = next(iter(test_ds))\n",
"random_idx = tf.random.uniform([], minval=0, maxval=BATCH_SIZE, dtype=tf.int32)\n",
"random_idx = keras.random.uniform([], minval=0, maxval=BATCH_SIZE, seed=10)\n",
"\n",
"test_image = images[random_idx].numpy().astype(\"float\")\n",
"test_mask = masks[random_idx].numpy().astype(\"float\")\n",
"test_image = images[int(random_idx)].numpy().astype(\"float\")\n",
"test_mask = masks[int(random_idx)].numpy().astype(\"float\")\n",
"\n",
"# Overlay segmentation mask on top of image.\n",
"fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))\n",
Expand Down Expand Up @@ -384,7 +384,7 @@
" activation=\"relu\",\n",
" padding=\"same\",\n",
" use_bias=False,\n",
" kernel_initializer=tf.constant_initializer(1.0),\n",
" kernel_initializer=keras.initializers.Constant(1.0),\n",
" )\n",
" dense_convs.append(dense_conv)\n",
" dropout_layer = keras.layers.Dropout(0.5)\n",
Expand Down Expand Up @@ -829,13 +829,13 @@
"outputs": [],
"source": [
"images, masks = next(iter(test_ds))\n",
"random_idx = tf.random.uniform([], minval=0, maxval=BATCH_SIZE, dtype=tf.int32)\n",
"random_idx = keras.random.uniform([], minval=0, maxval=BATCH_SIZE,seed=10)\n",
"\n",
"# Get random test image and mask\n",
"test_image = images[random_idx].numpy().astype(\"float\")\n",
"test_mask = masks[random_idx].numpy().astype(\"float\")\n",
"test_image = images[int(random_idx)].numpy().astype(\"float\")\n",
"test_mask = masks[int(random_idx)].numpy().astype(\"float\")\n",
"\n",
"pred_image = tf.expand_dims(test_image, axis=0)\n",
"pred_image = ops.expand_dims(test_image, axis=0)\n",
"pred_image = keras.applications.vgg19.preprocess_input(pred_image)\n",
"\n",
"# Perform inference on FCN-32S\n",
Expand Down
26 changes: 13 additions & 13 deletions examples/vision/md/fully_convolutional_network.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

**Author:** [Suvaditya Mukherjee](https://twitter.com/halcyonrayes)<br>
**Date created:** 2023/06/16<br>
**Last modified:** 2023/06/16<br>
**Last modified:** 2023/12/25<br>
**Description:** Using the Fully-Convolutional Network for Image Segmentation.


Expand Down Expand Up @@ -50,15 +50,15 @@ or a [PyImageSearch Blog on Semantic Segmentation](https://pyimagesearch.com/201


```python
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
from keras import ops
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import numpy as np

keras.utils.set_random_seed(27)
tf.random.set_seed(27)

AUTOTUNE = tf.data.AUTOTUNE
```

Expand Down Expand Up @@ -165,10 +165,10 @@ which makes the image and mask size same.
# for Matplotlib visualization.

images, masks = next(iter(test_ds))
random_idx = tf.random.uniform([], minval=0, maxval=BATCH_SIZE, dtype=tf.int32)
random_idx = keras.random.uniform([], minval=0, maxval=BATCH_SIZE, seed=10)

test_image = images[random_idx].numpy().astype("float")
test_mask = masks[random_idx].numpy().astype("float")
test_image = images[int(random_idx)].numpy().astype("float")
test_mask = masks[int(random_idx)].numpy().astype("float")

# Overlay segmentation mask on top of image.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
Expand Down Expand Up @@ -286,7 +286,7 @@ for filter_idx in range(len(units)):
activation="relu",
padding="same",
use_bias=False,
kernel_initializer=tf.constant_initializer(1.0),
kernel_initializer=keras.initializers.Constant(1.0),,
)
dense_convs.append(dense_conv)
dropout_layer = keras.layers.Dropout(0.5)
Expand Down Expand Up @@ -972,13 +972,13 @@ Note: For better results, the model must be trained for a higher number of epoch

```python
images, masks = next(iter(test_ds))
random_idx = tf.random.uniform([], minval=0, maxval=BATCH_SIZE, dtype=tf.int32)
random_idx = keras.random.uniform([], minval=0, maxval=BATCH_SIZE,seed=10)

# Get random test image and mask
test_image = images[random_idx].numpy().astype("float")
test_mask = masks[random_idx].numpy().astype("float")
test_image = images[int(random_idx)].numpy().astype("float")
test_mask = masks[int(random_idx)].numpy().astype("float")

pred_image = tf.expand_dims(test_image, axis=0)
pred_image = ops.expand_dims(test_image, axis=0)
pred_image = keras.applications.vgg19.preprocess_input(pred_image)

# Perform inference on FCN-32S
Expand Down

0 comments on commit 9428127

Please sign in to comment.