Commit
Showing 5 changed files with 269 additions and 113 deletions.
This file was deleted.
configs/deeplabv3p/deeplabv3p_resnet50_os8_cityscapes_1024x512_80k_rmiloss.yml (9 additions, 0 deletions)
_base_: 'deeplabv3p_resnet50_os8_cityscapes_1024x512_80k.yml'

loss:
  types:
    - type: MixedLoss
      losses:
        - type: CrossEntropyLoss
        - type: RMILoss
      coef: [0.5, 0.5]
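
For reference, here is a minimal sketch of the loss object this MixedLoss configuration builds at training time, assuming a standard paddleseg installation that exposes these loss classes (the coefficients mirror the coef list above):

from paddleseg.models.losses import CrossEntropyLoss, MixedLoss, RMILoss

# 0.5 * CrossEntropyLoss + 0.5 * RMILoss, matching the YAML coef list
mixed_loss = MixedLoss(
    losses=[CrossEntropyLoss(), RMILoss()],
    coef=[0.5, 0.5])

In PaddleSeg, the config itself is typically consumed by the training entry point, e.g. python train.py --config configs/deeplabv3p/deeplabv3p_resnet50_os8_cityscapes_1024x512_80k_rmiloss.yml.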
New file (256 additions, 0 deletions):
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RMI loss in PaddlePaddle."""
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg.cvlibs import manager

_euler_num = 2.718281828  # Euler's number e
_pi = 3.14159265  # pi
_ln_2_pi = 1.837877  # ln(2 * pi)
_CLIP_MIN = 1e-6  # min clip value after sigmoid
_CLIP_MAX = 1.0  # max clip value after sigmoid
_POS_ALPHA = 5e-4  # small value added to the diagonal to keep matrices positive-definite
_IS_SUM = 1  # sum the loss over classes instead of averaging


@manager.LOSSES.add_component
class RMILoss(nn.Layer):
    """
    Implements the Region Mutual Information (RMI) loss
    (https://arxiv.org/abs/1910.12037) for semantic segmentation.

    Unlike the vanilla RMI loss, which bundles a cross-entropy term, this
    implementation keeps only the RMI-related part. The motivation is to allow
    a more flexible combination of losses during training: for example, using
    MixedLoss to pair RMI loss with bootstrapped cross-entropy loss gives
    online hard-example mining together with attention to region information.

    Args:
        num_classes (int, optional): The number of classes. Default ``19``.
        rmi_radius (int, optional): The side length of the square region used
            to build point pairs. Default ``3``.
        rmi_pool_way (int, optional): How the maps are downsampled: 0 for max
            pooling, 1 for average pooling, 2 for interpolation. Default ``0``.
        rmi_pool_size (int, optional): The kernel size of the downsampling.
            Default ``3``.
        rmi_pool_stride (int, optional): The stride of the downsampling; must
            equal ``rmi_pool_size``. Default ``3``.
        loss_weight_lambda (float, optional): Weight retained from the original
            implementation; not used by this class. Default ``0.5``.
        ignore_index (int64, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default ``255``.
    """

    def __init__(self,
                 num_classes=19,
                 rmi_radius=3,
                 rmi_pool_way=0,
                 rmi_pool_size=3,
                 rmi_pool_stride=3,
                 loss_weight_lambda=0.5,
                 ignore_index=255):
        super(RMILoss, self).__init__()

        self.num_classes = num_classes
        # radius of the square region used to build point pairs
        assert rmi_radius in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        self.rmi_radius = rmi_radius
        # downsampling method: 0 max pooling, 1 average pooling, 2 interpolation
        assert rmi_pool_way in [0, 1, 2, 3]
        self.rmi_pool_way = rmi_pool_way
        # the pool size must equal the pool stride
        assert rmi_pool_size == rmi_pool_stride
        self.rmi_pool_size = rmi_pool_size
        self.rmi_pool_stride = rmi_pool_stride
        self.weight_lambda = loss_weight_lambda
        self.half_d = self.rmi_radius * self.rmi_radius
        self.d = 2 * self.half_d
        self.kernel_padding = self.rmi_pool_size // 2
        self.ignore_index = ignore_index

    def forward(self, logits_4D, labels_4D, do_rmi=True):
        """
        Forward computation.

        Args:
            logits_4D (Tensor): Shape is [N, C, H, W], logits at each prediction
                (between -inf and +inf).
            labels_4D (Tensor): Shape is [N, H, W], ground truth labels
                (between 0 and C - 1).
        """
        logits_4D = paddle.cast(logits_4D, dtype='float32')
        labels_4D = paddle.cast(labels_4D, dtype='float32')

        loss = self.forward_sigmoid(logits_4D, labels_4D, do_rmi=do_rmi)
        return loss

    def forward_sigmoid(self, logits_4D, labels_4D, do_rmi=False):
        """
        Apply the sigmoid operation to get per-class probabilities.

        Args:
            logits_4D : [N, C, H, W], dtype=float32
            labels_4D : [N, H, W], dtype=long
            do_rmi : bool
        """
        # build one-hot labels, zeroing out pixels marked with ignore_index
        label_mask_3D = labels_4D != self.ignore_index
        valid_onehot_labels_4D = paddle.cast(
            F.one_hot(
                paddle.cast(labels_4D, dtype='int64') * paddle.cast(
                    label_mask_3D, dtype='int64'),
                num_classes=self.num_classes),
            dtype='float32')

        # cast the bool mask to float32 so it can be broadcast-multiplied
        label_mask_3D = paddle.cast(label_mask_3D, dtype='float32')
        valid_onehot_labels_4D = valid_onehot_labels_4D * paddle.unsqueeze(
            label_mask_3D, axis=3)
        valid_onehot_labels_4D.stop_gradient = True
        probs_4D = F.sigmoid(logits_4D) * paddle.unsqueeze(
            label_mask_3D, axis=1) + _CLIP_MIN

        # [N, H, W, C] -> [N, C, H, W]
        valid_onehot_labels_4D = paddle.transpose(valid_onehot_labels_4D,
                                                  [0, 3, 1, 2])
        valid_onehot_labels_4D.stop_gradient = True
        rmi_loss = self.rmi_lower_bound(valid_onehot_labels_4D, probs_4D)

        return rmi_loss

    def inverse(self, x):
        return paddle.inverse(x)

    def rmi_lower_bound(self, labels_4D, probs_4D):
        """
        Calculate the lower bound of the region mutual information.

        Args:
            labels_4D : [N, C, H, W], dtype=float32
            probs_4D : [N, C, H, W], dtype=float32
        """
        assert labels_4D.shape == probs_4D.shape, (
            'shapes do not match: {} vs {}'.format(labels_4D.shape,
                                                   probs_4D.shape))

        p, s = self.rmi_pool_size, self.rmi_pool_stride
        if self.rmi_pool_stride > 1:
            # downsample the maps to reduce memory and computation cost
            if self.rmi_pool_way == 0:
                labels_4D = F.max_pool2d(
                    labels_4D,
                    kernel_size=p,
                    stride=s,
                    padding=self.kernel_padding)
                probs_4D = F.max_pool2d(
                    probs_4D,
                    kernel_size=p,
                    stride=s,
                    padding=self.kernel_padding)
            elif self.rmi_pool_way == 1:
                labels_4D = F.avg_pool2d(
                    labels_4D,
                    kernel_size=p,
                    stride=s,
                    padding=self.kernel_padding)
                probs_4D = F.avg_pool2d(
                    probs_4D,
                    kernel_size=p,
                    stride=s,
                    padding=self.kernel_padding)
            elif self.rmi_pool_way == 2:
                shape = labels_4D.shape
                new_h, new_w = shape[2] // s, shape[3] // s
                labels_4D = F.interpolate(
                    labels_4D, size=(new_h, new_w), mode='nearest')
                probs_4D = F.interpolate(
                    probs_4D,
                    size=(new_h, new_w),
                    mode='bilinear',
                    align_corners=True)
            else:
                raise NotImplementedError("Pool way of RMI is not defined!")

        label_shape = labels_4D.shape
        n, c = label_shape[0], label_shape[1]

        # collect the point pairs within each rmi_radius x rmi_radius region
        la_vectors, pr_vectors = self.map_get_pairs(
            labels_4D, probs_4D, radius=self.rmi_radius, is_combine=0)

        la_vectors = paddle.reshape(la_vectors, [n, c, self.half_d, -1])
        la_vectors = paddle.cast(la_vectors, dtype='float64')
        la_vectors.stop_gradient = True

        pr_vectors = paddle.reshape(pr_vectors, [n, c, self.half_d, -1])
        pr_vectors = paddle.cast(pr_vectors, dtype='float64')

        # center the vectors, then compute the covariance matrices
        diag_matrix = paddle.unsqueeze(
            paddle.unsqueeze(paddle.eye(self.half_d), axis=0), axis=0)
        la_vectors = la_vectors - paddle.mean(la_vectors, axis=3, keepdim=True)

        la_cov = paddle.matmul(la_vectors,
                               paddle.transpose(la_vectors, [0, 1, 3, 2]))
        pr_vectors = pr_vectors - paddle.mean(pr_vectors, axis=3, keepdim=True)
        pr_cov = paddle.matmul(pr_vectors,
                               paddle.transpose(pr_vectors, [0, 1, 3, 2]))

        # regularize the diagonal so the matrix stays invertible
        pr_cov_inv = self.inverse(
            pr_cov + paddle.cast(diag_matrix, dtype='float64') * _POS_ALPHA)

        la_pr_cov = paddle.matmul(la_vectors,
                                  paddle.transpose(pr_vectors, [0, 1, 3, 2]))

        # Schur complement: the approximated posterior covariance of the labels
        appro_var = la_cov - paddle.matmul(
            paddle.matmul(la_pr_cov, pr_cov_inv),
            paddle.transpose(la_pr_cov, [0, 1, 3, 2]))

        rmi_now = 0.5 * self.log_det_by_cholesky(
            appro_var + paddle.cast(diag_matrix, dtype='float64') * _POS_ALPHA)

        # average over the batch, then normalize by the number of point pairs
        rmi_per_class = paddle.cast(
            paddle.mean(
                paddle.reshape(rmi_now, [-1, self.num_classes]), axis=0),
            dtype='float32')
        rmi_per_class = paddle.divide(rmi_per_class,
                                      paddle.to_tensor(float(self.half_d)))

        rmi_loss = paddle.sum(rmi_per_class) if _IS_SUM else paddle.mean(
            rmi_per_class)

        return rmi_loss

    def log_det_by_cholesky(self, matrix):
        """
        Args:
            matrix: a positive-definite matrix of shape [N, C, D, D].
        """
        # log det(A) = 2 * sum(log(diag(L))) where L is the Cholesky factor of A
        chol = paddle.cholesky(matrix)
        diag = paddle.diagonal(chol, offset=0, axis1=-2, axis2=-1)
        log_diag = paddle.log(diag + 1e-8)

        return 2.0 * paddle.sum(log_diag, axis=-1)

    def map_get_pairs(self, labels_4D, probs_4D, radius=3, is_combine=True):
        """
        Args:
            labels_4D : labels, shape [N, C, H, W]
            probs_4D : probabilities, shape [N, C, H, W]
            radius : the side length of the square region
            is_combine : if True, stack the label and probability vectors into
                one tensor; if False, return them separately
        Return:
            tensor with shape [N, C, radius * radius, H - (radius - 1), W - (radius - 1)]
        """
        label_shape = labels_4D.shape
        h, w = label_shape[2], label_shape[3]
        new_h, new_w = h - (radius - 1), w - (radius - 1)
        # slide a radius x radius window; each offset contributes one vector entry
        la_ns = []
        pr_ns = []
        for y in range(0, radius, 1):
            for x in range(0, radius, 1):
                la_now = labels_4D[:, :, y:y + new_h, x:x + new_w]
                pr_now = probs_4D[:, :, y:y + new_h, x:x + new_w]
                la_ns.append(la_now)
                pr_ns.append(pr_now)

        if is_combine:
            pair_ns = la_ns + pr_ns
            p_vectors = paddle.stack(pair_ns, axis=2)
            return p_vectors
        else:
            la_vectors = paddle.stack(la_ns, axis=2)
            pr_vectors = paddle.stack(pr_ns, axis=2)
            return la_vectors, pr_vectors
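
As a quick sanity check, here is a minimal sketch that runs the loss on random inputs. It assumes a working paddle installation; the shapes follow the forward() docstring above, and everything besides RMILoss itself is illustrative.

import paddle

# hypothetical smoke test: Cityscapes-like settings, random data
loss_fn = RMILoss(num_classes=19)
logits = paddle.randn([2, 19, 96, 96])             # [N, C, H, W]
labels = paddle.randint(0, 19, shape=[2, 96, 96])  # [N, H, W]
loss = loss_fn(logits, labels)
print(float(loss))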