Skip to content

Commit

Permalink
Monocular depth estimation - Keras 3 Migration (Only Tensorflow Backe…
Browse files Browse the repository at this point in the history
…nd) (keras-team#1910)

* Monocular depth estimation - Keras 3 Migration (Only Tensorflow Backend)

* trim output

* Added PyDataset
  • Loading branch information
chunduriv authored Aug 31, 2024
1 parent 4301f03 commit ff969fe
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 165 deletions.
93 changes: 62 additions & 31 deletions examples/vision/depth_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Title: Monocular depth estimation
Author: [Victor Basu](https://www.linkedin.com/in/victor-basu-520958147)
Date created: 2021/08/30
Last modified: 2021/08/30
Last modified: 2024/08/13
Description: Implement a depth estimation model with a convnet.
Accelerator: GPU
"""
Expand All @@ -25,17 +25,21 @@
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import sys

import tensorflow as tf
from tensorflow.keras import layers

import keras
from keras import layers
from keras import ops
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt

tf.random.set_seed(123)
keras.utils.set_random_seed(123)

"""
## Downloading the dataset
Expand All @@ -52,7 +56,7 @@

annotation_folder = "/dataset/"
if not os.path.exists(os.path.abspath(".") + annotation_folder):
annotation_zip = tf.keras.utils.get_file(
annotation_zip = keras.utils.get_file(
"val.tar.gz",
cache_subdir=os.path.abspath("."),
origin="http://diode-dataset.s3.amazonaws.com/val.tar.gz",
Expand Down Expand Up @@ -89,7 +93,7 @@

HEIGHT = 256
WIDTH = 256
LR = 0.0002
LR = 0.00001
EPOCHS = 30
BATCH_SIZE = 32

Expand All @@ -105,8 +109,9 @@
"""


class DataGenerator(tf.keras.utils.Sequence):
class DataGenerator(keras.utils.PyDataset):
def __init__(self, data, batch_size=6, dim=(768, 1024), n_channels=3, shuffle=True):
super().__init__()
"""
Initialization
"""
Expand Down Expand Up @@ -178,7 +183,7 @@ def data_generation(self, batch):
self.data["depth"][batch_id],
self.data["mask"][batch_id],
)

x, y = x.astype("float32"), y.astype("float32")
return x, y


Expand Down Expand Up @@ -249,10 +254,10 @@ def __init__(
super().__init__(**kwargs)
self.convA = layers.Conv2D(filters, kernel_size, strides, padding)
self.convB = layers.Conv2D(filters, kernel_size, strides, padding)
self.reluA = layers.LeakyReLU(alpha=0.2)
self.reluB = layers.LeakyReLU(alpha=0.2)
self.bn2a = tf.keras.layers.BatchNormalization()
self.bn2b = tf.keras.layers.BatchNormalization()
self.reluA = layers.LeakyReLU(negative_slope=0.2)
self.reluB = layers.LeakyReLU(negative_slope=0.2)
self.bn2a = layers.BatchNormalization()
self.bn2b = layers.BatchNormalization()

self.pool = layers.MaxPool2D((2, 2), (2, 2))

Expand All @@ -278,10 +283,10 @@ def __init__(
self.us = layers.UpSampling2D((2, 2))
self.convA = layers.Conv2D(filters, kernel_size, strides, padding)
self.convB = layers.Conv2D(filters, kernel_size, strides, padding)
self.reluA = layers.LeakyReLU(alpha=0.2)
self.reluB = layers.LeakyReLU(alpha=0.2)
self.bn2a = tf.keras.layers.BatchNormalization()
self.bn2b = tf.keras.layers.BatchNormalization()
self.reluA = layers.LeakyReLU(negative_slope=0.2)
self.reluB = layers.LeakyReLU(negative_slope=0.2)
self.bn2a = layers.BatchNormalization()
self.bn2b = layers.BatchNormalization()
self.conc = layers.Concatenate()

def call(self, x, skip):
Expand All @@ -305,8 +310,8 @@ def __init__(
super().__init__(**kwargs)
self.convA = layers.Conv2D(filters, kernel_size, strides, padding)
self.convB = layers.Conv2D(filters, kernel_size, strides, padding)
self.reluA = layers.LeakyReLU(alpha=0.2)
self.reluB = layers.LeakyReLU(alpha=0.2)
self.reluA = layers.LeakyReLU(negative_slope=0.2)
self.reluB = layers.LeakyReLU(negative_slope=0.2)

def call(self, x):
x = self.convA(x)
Expand All @@ -328,13 +333,39 @@ def call(self, x):
"""


class DepthEstimationModel(tf.keras.Model):
def image_gradients(image):
if len(ops.shape(image)) != 4:
raise ValueError(
"image_gradients expects a 4D tensor "
"[batch_size, h, w, d], not {}.".format(ops.shape(image))
)

image_shape = ops.shape(image)
batch_size, height, width, depth = ops.unstack(image_shape)

dy = image[:, 1:, :, :] - image[:, :-1, :, :]
dx = image[:, :, 1:, :] - image[:, :, :-1, :]

# Return tensors with same size as original image by concatenating
# zeros. Place the gradient [I(x+1,y) - I(x,y)] on the base pixel (x, y).
shape = ops.stack([batch_size, 1, width, depth])
dy = ops.concatenate([dy, ops.zeros(shape, dtype=image.dtype)], axis=1)
dy = ops.reshape(dy, image_shape)

shape = ops.stack([batch_size, height, 1, depth])
dx = ops.concatenate([dx, ops.zeros(shape, dtype=image.dtype)], axis=2)
dx = ops.reshape(dx, image_shape)

return dy, dx


class DepthEstimationModel(keras.Model):
def __init__(self):
super().__init__()
self.ssim_loss_weight = 0.85
self.l1_loss_weight = 0.1
self.edge_loss_weight = 0.9
self.loss_metric = tf.keras.metrics.Mean(name="loss")
self.loss_metric = keras.metrics.Mean(name="loss")
f = [16, 32, 64, 128, 256]
self.downscale_blocks = [
DownscaleBlock(f[0]),
Expand All @@ -353,28 +384,28 @@ def __init__(self):

def calculate_loss(self, target, pred):
# Edges
dy_true, dx_true = tf.image.image_gradients(target)
dy_pred, dx_pred = tf.image.image_gradients(pred)
weights_x = tf.exp(tf.reduce_mean(tf.abs(dx_true)))
weights_y = tf.exp(tf.reduce_mean(tf.abs(dy_true)))
dy_true, dx_true = image_gradients(target)
dy_pred, dx_pred = image_gradients(pred)
weights_x = ops.cast(ops.exp(ops.mean(ops.abs(dx_true))), "float32")
weights_y = ops.cast(ops.exp(ops.mean(ops.abs(dy_true))), "float32")

# Depth smoothness
smoothness_x = dx_pred * weights_x
smoothness_y = dy_pred * weights_y

depth_smoothness_loss = tf.reduce_mean(abs(smoothness_x)) + tf.reduce_mean(
depth_smoothness_loss = ops.mean(abs(smoothness_x)) + ops.mean(
abs(smoothness_y)
)

# Structural similarity (SSIM) index
ssim_loss = tf.reduce_mean(
ssim_loss = ops.mean(
1
- tf.image.ssim(
target, pred, max_val=WIDTH, filter_size=7, k1=0.01**2, k2=0.03**2
)
)
# Point-wise depth
l1_loss = tf.reduce_mean(tf.abs(target - pred))
l1_loss = ops.mean(ops.abs(target - pred))

loss = (
(self.ssim_loss_weight * ssim_loss)
Expand Down Expand Up @@ -432,9 +463,9 @@ def call(self, x):
## Model training
"""

optimizer = tf.keras.optimizers.Adam(
optimizer = keras.optimizers.SGD(
learning_rate=LR,
amsgrad=False,
nesterov=False,
)
model = DepthEstimationModel()
# Compile the model
Expand Down Expand Up @@ -491,9 +522,9 @@ def call(self, x):
## References
The following papers go deeper into possible approaches for depth estimation.
1. [Depth Prediction Without the Sensors: Leveraging Structure for Unsupervised Learning from Monocular Videos](https://arxiv.org/pdf/1811.06152v1.pdf)
1. [Depth Prediction Without the Sensors: Leveraging Structure for Unsupervised Learning from Monocular Videos](https://arxiv.org/abs/1811.06152v1)
2. [Digging Into Self-Supervised Monocular Depth Estimation](https://openaccess.thecvf.com/content_ICCV_2019/papers/Godard_Digging_Into_Self-Supervised_Monocular_Depth_Estimation_ICCV_2019_paper.pdf)
3. [Deeper Depth Prediction with Fully Convolutional Residual Networks](https://arxiv.org/pdf/1606.00373v2.pdf)
3. [Deeper Depth Prediction with Fully Convolutional Residual Networks](https://arxiv.org/abs/1606.00373v2)
You can also find helpful implementations in the papers with code depth estimation task.
Expand Down
Binary file modified examples/vision/img/depth_estimation/depth_estimation_13_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/vision/img/depth_estimation/depth_estimation_15_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit ff969fe

Please sign in to comment.