Skip to content

Commit

Permalink
✨ Add topographic loss to Generator Network's loss function
Browse files Browse the repository at this point in the history
Implement a 'topographic loss' that tries to ensure the predicted high resolution DeepBedMap DEM is topographically similar to the low resolution BEDMAP2 DEM. Currently hardcoded to work on 4x upsampling only, and I've removed some old references in the GeneratorModel class that had a 'scaling' setting for other upsampling factors (e.g. 2, 6, 8, etc) which was never implemented. Also quickly patching 75b7493, as the YUML diagram did not mention the bilinear resampling on W2...
  • Loading branch information
weiji14 committed Aug 20, 2019
1 parent 1b717a0 commit d599ee8
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 116 deletions.
89 changes: 29 additions & 60 deletions deepbedmap.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions deepbedmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def get_deepbedmap_model_inputs(
print(X_tile.shape, W1_tile.shape, W2_tile.shape, W3_tile.shape)

# Build quilt package for datasets covering our test region
reupload = True
reupload = False
if reupload == True:
quilt.build(package="weiji14/deepbedmap/model/test/W1_tile", path=W1_tile)
quilt.build(package="weiji14/deepbedmap/model/test/W2_tile", path=W2_tile)
Expand Down Expand Up @@ -334,7 +334,7 @@ def plot_3d_view(

# %%
def load_trained_model(
experiment_key: str = "cf156ecbac43467fbb014d1964041066", # or simply use "latest"
experiment_key: str = "ac0e3aba2c6a457fa0a717bb4844621e", # or simply use "latest"
model_weights_path: str = "model/weights/srgan_generator_model_weights.npz",
):
"""
Expand Down
110 changes: 76 additions & 34 deletions srgan_train.ipynb

Large diffs are not rendered by default.

74 changes: 57 additions & 17 deletions srgan_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,13 +417,15 @@ def forward(self, x):
# %% [markdown]
# ### 2.1.3 Build the Generator Network, with upsampling layers!
#
# ![4 inputs feeding into the Generator Network, producing a high resolution prediction output](https://yuml.me/bea381ef.png)
# ![4 inputs feeding into the Generator Network, producing a high resolution prediction output](https://yuml.me/df9919c6.png)
#
# <!--[W3_input(ACCUMULATION)|1x9x9]-k4n32s1>[W3_inter|32x8x8],[W3_inter]->[Concat|128x8x8]
# [W2_input(MEASURES)|2x20x20]-k8n32s1>[W2_inter|32x8x8],[W2_inter]->[Concat|128x8x8]
# <!--
# [W3_input(ACCUMULATION)|1x9x9]-k4n32s1>[W3_inter|32x8x8],[W3_inter]->[Concat|128x8x8]
# [W2_input(MEASURES)|2x20x20]-bilinear_resize>[W2_resized|2x18x18],[W2_resized]-k8n32s2>[W2_inter|32x8x8],[W2_inter]->[Concat|128x8x8]
# [W1_input(REMA)|1x90x90]-k40n32s10>[W1_inter|32x8x8],[W1_inter]->[Concat|128x8x8]
# [X_input(BEDMAP2)|1x9x9]-k4n32s1>[X_inter|32x8x8],[X_inter]->[Concat|128x8x8]
# [Concat|8x8x128]->[Generator-Network|Many-Residual-Blocks],[Generator-Network]->[Y_hat(High-Resolution_DEM)|1x36x36]-->
# [Concat|8x8x128]->[Generator-Network|Many-Residual-Blocks],[Generator-Network]->[Y_hat(High-Resolution_DEM)|1x36x36]
# -->

# %%
class GeneratorModel(chainer.Chain):
Expand All @@ -434,9 +436,8 @@ class GeneratorModel(chainer.Chain):
Glues the input block with several residual blocks and upsampling layers
Parameters:
input_shape -- shape of input tensor in tuple format (height, width, channels)
num_residual_blocks -- how many Conv-LeakyReLU-Conv blocks to use
scaling -- even numbered integer to increase resolution (e.g. 0, 2, 4, 6, 8)
num_residual_blocks -- how many Residual-in-Residual Dense Blocks to use
residual_scaling -- scale factor for residuals before adding to parent branch
out_channels -- integer representing number of output channels/filters/kernels
Example:
Expand Down Expand Up @@ -566,7 +567,7 @@ def forward(
a3 = F.add(a1, a3)

# 4th part
# Upsampling (if 4; run twice, if 8; run thrice, etc.)
# Upsampling (hardcoded to be 4x, actually 2x run twice)
# Convert shape from 8x8 to 9x9 using Convolution2D k2n64s1
a4_0 = self.pre_upsample_conv_layer(a3)
# Uses Nearest Neighbour Interpolation followed by Convolution2D k3n64s1
Expand Down Expand Up @@ -723,13 +724,15 @@ def forward(self, x: cupy.ndarray):
#
# Now we define the Perceptual Loss function for our Generator and Discriminator neural network models, where:
#
# $$Perceptual Loss = Content Loss + Adversarial Loss$$
# $$Perceptual Loss = Content Loss + Adversarial Loss + Topographic Loss$$
#
# ![Perceptual Loss in an Enhanced Super Resolution Generative Adversarial Network](https://yuml.me/db58d683.png)
# ![Perceptual Loss in an adapted Enhanced Super Resolution Generative Adversarial Network](https://yuml.me/7731ae34.png)
#
# <!--
# [LowRes-Inputs]-Generator>[SuperResolution_DEM]
# [SuperResolution_DEM]-.->[note:Content-Loss|MeanAbsoluteError{bg:yellow}]
# [LowRes-Inputs]-.->[note:Topographic-Loss|MeanAbsoluteError{bg:yellow}]
# [SuperResolution_DEM]-.->[note:Topographic-Loss]
# [HighRes-Groundtruth_DEM]-.->[note:Content-Loss]
# [SuperResolution_DEM]-Discriminator>[False_or_True_Prediction]
# [HighRes-Groundtruth_DEM]-Discriminator>[False_or_True_Prediction]
Expand All @@ -738,13 +741,14 @@ def forward(self, x: cupy.ndarray):
# [False_or_True_Label]-.->[note:Adversarial-Loss]
# [note:Content-Loss]-.->[note:Perceptual-Loss{bg:gold}]
# [note:Adversarial-Loss]-.->[note:Perceptual-Loss{bg:gold}]
# [note:Topographic-Loss]-.->[note:Perceptual-Loss{bg:gold}]
# -->

# %% [markdown]
# ### Content Loss
#
# The original SRGAN paper by [Ledig et al. 2017](https://arxiv.org/abs/1609.04802v5) calculates *Content Loss* based on the ReLU activation layers of the pre-trained 19 layer VGG network.
# The implementation below is less advanced, simply using an L1 loss, i.e., a pixel-wise [Mean Absolute Error (MAE) loss](https://keras.io/losses/#mean_absolute_error) as the *Content Loss*.
# The implementation below is less advanced, simply using an L1 loss, i.e., a pixel-wise [Mean Absolute Error (MAE) loss](https://docs.chainer.org/en/latest/reference/generated/chainer.functions.mean_absolute_error.html) as the *Content Loss*.
# Specifically, the *Content Loss* is calculated as the MAE difference between the output of the generator model (i.e. the predicted Super Resolution Image) and that of the groundtruth image (i.e. the true High Resolution Image).
#
# $$ e_i = ||G(x_{i}) - y_i||_{1} $$
Expand Down Expand Up @@ -809,6 +813,28 @@ def forward(self, x: cupy.ndarray):
#
# See also how [Pytorch](https://pytorch.org/docs/stable/nn.html?highlight=bcewithlogitsloss#torch.nn.BCEWithLogitsLoss) and [Tensorflow](https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits) implements this in a numerically stable manner.

# %% [markdown]
# ### Topographic Loss
#
# In addition to the L1 Content Loss, we further define a *Topographic Loss*.
# Specifically, we want the averaged value of each 4x4 grid of the predicted DeepBedMap image to correspond to a 1x1 pixel on the BEDMAP2 image.
#
# Due to BEDMAP2 having a 4x lower resolution than the predicted DeepBedMap DEM (1000m compared to 250m), we first apply a 4x4 [Mean/Average Pooling](https://docs.chainer.org/en/latest/reference/generated/chainer.functions.average_pooling_2d.html) operation on the DeepBedMap image, turning each 4x4 grid into a single pixel so that the pooled image matches the shape of the BEDMAP2 image.
#
# $$ \bar{y_j} = Mean = \dfrac{1}{n} \sum\limits_{i=1}^n y_i $$
#
# where $\bar{y_j}$ is the mean/average of all predicted pixel values $y_i$ across the $n = 16$ DeepBedMap pixels (indexed by $i$) within a 4x4 grid corresponding to the spatial location of one (BEDMAP2) pixel at position $j$.
# Then, we calculate the [MAE](https://docs.chainer.org/en/latest/reference/generated/chainer.functions.mean_absolute_error.html) difference between the mean-pooled output of the generator model (i.e. the predicted Super-Resolution DeepBedMap bed DEM Image, averaged over each 4x4 grid) and the BEDMAP2 image (i.e. the original Low Resolution Image we are super-resolving).
#
# $$ e_j = ||\bar{y_j} - x_j||_{1} $$
#
# where $\bar{y_j}$ is the mean of the 4x4 pixels we just calculated above, and $x_j$ is the spatially corresponding BEDMAP2 pixel, respectively at BEDMAP2 pixel $j$.
# $e_j$ thus represents the absolute error (L1 loss) (denoted by $||\dots||_{1}$) between the (averaged) super-resolution and low-resolution values.
# We then sum all the pixel-wise errors $e_1,\dots,e_m$ and divide by the number of BEDMAP2 pixels $m$ to get the Arithmetic Mean $\dfrac{1}{m} \sum\limits_{j=1}^m e_j$ of our error which is our *Topographic Loss*.
#
# $$ Loss_{Topographic} = Mean Absolute Error = \dfrac{1}{m} \sum\limits_{j=1}^m e_j $$
#

# %%
def calculate_generator_loss(
y_pred: chainer.variable.Variable,
Expand All @@ -817,25 +843,28 @@ def calculate_generator_loss(
real_labels: cupy.ndarray,
fake_minus_real_target: cupy.ndarray,
real_minus_fake_target: cupy.ndarray,
x_topo: cupy.ndarray,
content_loss_weighting: float = 1e-2,
adversarial_loss_weighting: float = 5e-3,
topographic_loss_weighting: float = 5e-3,
) -> chainer.variable.Variable:
"""
This function calculates the weighted sum between
"Content Loss" and "Adversarial Loss".
"Content Loss", "Adversarial Loss", and "Topographic Loss"
which forms the basis for training the Generator Network.
>>> calculate_generator_loss(
... y_pred=chainer.variable.Variable(data=np.ones(shape=(2, 1, 3, 3))),
... y_true=np.full(shape=(2, 1, 3, 3), fill_value=10.0),
... y_pred=chainer.variable.Variable(data=np.ones(shape=(2, 1, 12, 12))),
... y_true=np.full(shape=(2, 1, 12, 12), fill_value=10.0),
... fake_labels=np.array([[-1.2], [0.5]]),
... real_labels=np.array([[0.5], [-0.8]]),
... fake_minus_real_target=np.array([[1], [1]]).astype(np.int32),
... real_minus_fake_target=np.array([[0], [0]]).astype(np.int32),
... x_topo=np.full(shape=(2, 1, 3, 3), fill_value=9.0),
... )
variable(0.09867307)
variable(0.13867307)
"""
# Content Loss (L1, Mean Absolute Error) between 2D images
# Content Loss (L1, Mean Absolute Error) between predicted and groundtruth 2D images
content_loss = F.mean_absolute_error(x0=y_pred, x1=y_true)

# Adversarial Loss between 1D labels
Expand All @@ -846,10 +875,19 @@ def calculate_generator_loss(
fake_minus_real_target=fake_minus_real_target, # Ones (1) instead of zeros (0)
)

# Topographic Loss (L1, Mean Absolute Error) between predicted and low res 2D images
topographic_loss = F.mean_absolute_error(
x0=F.average_pooling_2d(x=y_pred, ksize=(4, 4)), x1=x_topo
)

# Get generator loss
weighted_content_loss = content_loss_weighting * content_loss
weighted_adversarial_loss = adversarial_loss_weighting * adversarial_loss
g_loss = weighted_content_loss + weighted_adversarial_loss
weighted_topographic_loss = topographic_loss_weighting * topographic_loss

g_loss = (
weighted_content_loss + weighted_adversarial_loss + weighted_topographic_loss
)

return g_loss

Expand Down Expand Up @@ -1167,6 +1205,8 @@ def train_eval_generator(
real_labels=real_labels, # real label 'should' get close to 0
fake_minus_real_target=fake_minus_real_target, # where 1 (fake) - 0 (real) = 1 (target)
real_minus_fake_target=real_minus_fake_target, # where 0 (real) - 1 (fake) = 0 (target)?
# topographic loss inputs, 2D image of low resolution input
x_topo=input_arrays["X"],
)
g_psnr = psnr(y_pred=fake_images.array, y_true=real_images)

Expand Down
7 changes: 4 additions & 3 deletions test_ipynb.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,16 @@
"ok\n",
"Trying:\n",
" calculate_generator_loss(\n",
" y_pred=chainer.variable.Variable(data=np.ones(shape=(2, 1, 3, 3))),\n",
" y_true=np.full(shape=(2, 1, 3, 3), fill_value=10.0),\n",
" y_pred=chainer.variable.Variable(data=np.ones(shape=(2, 1, 12, 12))),\n",
" y_true=np.full(shape=(2, 1, 12, 12), fill_value=10.0),\n",
" fake_labels=np.array([[-1.2], [0.5]]),\n",
" real_labels=np.array([[0.5], [-0.8]]),\n",
" fake_minus_real_target=np.array([[1], [1]]).astype(np.int32),\n",
" real_minus_fake_target=np.array([[0], [0]]).astype(np.int32),\n",
" x_topo=np.full(shape=(2, 1, 3, 3), fill_value=9.0),\n",
" )\n",
"Expecting:\n",
" variable(0.09867307)\n",
" variable(0.13867307)\n",
"ok\n",
"Trying:\n",
" psnr(\n",
Expand Down

0 comments on commit d599ee8

Please sign in to comment.