-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
1,532 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from mmengine.registry import Registry | ||
OPENOCC_LOSS = Registry('openocc_loss') | ||
|
||
from .multi_loss import MultiLoss | ||
from .rgb_loss_ms import RGBLossMS, SemLossMS, SemCELossMS | ||
from .reproj_loss_mono_multi_new import ReprojLossMonoMultiNew | ||
from .reproj_loss_mono_multi_new_combine import ReprojLossMonoMultiNewCombine | ||
from .edge_loss_3d_ms import EdgeLoss3DMS | ||
from .eikonal_loss import EikonalLoss | ||
from .sparsity_loss import SparsityLoss, HardSparsityLoss, SoftSparsityLoss, AdaptiveSparsityLoss | ||
from .second_grad_loss import SecondGradLoss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import torch.nn as nn | ||
from utils.tb_wrapper import WrappedTBWriter | ||
if 'selfocc' in WrappedTBWriter._instance_dict: | ||
writer = WrappedTBWriter.get_instance('selfocc') | ||
else: | ||
writer = None | ||
|
||
class BaseLoss(nn.Module): | ||
|
||
""" Base loss class. | ||
args: | ||
weight: weight of current loss. | ||
input_keys: keys for actual inputs to calculate_loss(). | ||
Since "inputs" may contain many different fields, we use input_keys | ||
to distinguish them. | ||
loss_func: the actual loss func to calculate loss. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
weight=1.0, | ||
input_dict={ | ||
'input': 'input'}, | ||
**kwargs): | ||
super().__init__() | ||
self.weight = weight | ||
self.input_dict = input_dict | ||
self.loss_func = lambda: 0 | ||
self.writer = writer | ||
|
||
# def calculate_loss(self, **kwargs): | ||
# return self.loss_func(*[kwargs[key] for key in self.input_keys]) | ||
|
||
def forward(self, inputs): | ||
actual_inputs = {} | ||
for input_key, input_val in self.input_dict.items(): | ||
actual_inputs.update({input_key: inputs[input_val]}) | ||
# return self.weight * self.calculate_loss(**actual_inputs) | ||
return self.weight * self.loss_func(**actual_inputs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import torch | ||
from .base_loss import BaseLoss | ||
from . import OPENOCC_LOSS | ||
import torch.nn.functional as F | ||
|
||
|
||
def get_smooth_loss(disp, img): | ||
"""Computes the smoothness loss for a disparity image | ||
The color image is used for edge-aware smoothness | ||
""" | ||
grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:]) | ||
grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :]) | ||
|
||
grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True) | ||
grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True) | ||
|
||
grad_disp_x *= torch.exp(-grad_img_x) | ||
grad_disp_y *= torch.exp(-grad_img_y) | ||
|
||
return grad_disp_x.mean() + grad_disp_y.mean() | ||
|
||
|
||
@OPENOCC_LOSS.register_module() | ||
class EdgeLoss3DMS(BaseLoss): | ||
|
||
def __init__(self, weight=1.0, input_dict=None, **kwargs): | ||
super().__init__(weight) | ||
|
||
if input_dict is None: | ||
self.input_dict = { | ||
'curr_imgs': 'curr_imgs', | ||
'ms_depths': 'ms_depths', | ||
'ms_rays': 'ms_rays' | ||
} | ||
else: | ||
self.input_dict = input_dict | ||
self.img_size = kwargs.get('img_size', [768, 1600]) | ||
self.ray_resize = kwargs.get('ray_resize', None) | ||
self.use_inf_mask = kwargs.get('use_inf_mask', False) | ||
# self.inf_dist = kwargs.get('inf_dist', 1e6) | ||
assert self.ray_resize is not None | ||
self.loss_func = self.edge_loss | ||
|
||
def edge_loss(self, curr_imgs, ms_depths, ms_rays, ms_accs=None, max_depths=None): | ||
# curr_imgs: B, N, C, H, W | ||
# depth: B, N, R | ||
# rays: R, 2 | ||
if self.use_inf_mask: | ||
assert ms_accs is not None and max_depths is not None | ||
if not isinstance(ms_rays, list): | ||
ms_rays = [ms_rays] * len(ms_depths) | ||
bs, num_cams, num_rays = ms_depths[0].shape | ||
|
||
tot_loss = 0. | ||
for scale, (depth, rays) in enumerate(zip(ms_depths, ms_rays)): | ||
pixel_curr = rays.clone().reshape(1, 1, num_rays, 2).repeat( | ||
bs * num_cams, 1, 1, 1) # bs*N, 1, R, 2 | ||
pixel_curr[..., 0] /= self.img_size[1] | ||
pixel_curr[..., 1] /= self.img_size[0] | ||
pixel_curr = pixel_curr * 2 - 1 | ||
rgb_curr = F.grid_sample( | ||
curr_imgs.flatten(0, 1), | ||
pixel_curr, | ||
mode='bilinear', | ||
padding_mode='border', | ||
align_corners=True) # bs*N, 3, 1, R | ||
rgb_curr = rgb_curr.reshape(bs * num_cams, -1, *self.ray_resize) | ||
|
||
if self.use_inf_mask: | ||
depth = depth * ms_accs[scale] + max_depths[scale] * (1 - ms_accs[scale]) | ||
|
||
depth = depth.reshape(bs * num_cams, 1, *self.ray_resize) | ||
mean_depth = depth.mean(2, True).mean(3, True) | ||
norm_depth = depth / (mean_depth + 1e-6) | ||
smooth_loss = get_smooth_loss(norm_depth, rgb_curr) | ||
|
||
tot_loss += smooth_loss | ||
|
||
return tot_loss / len(ms_depths) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from .base_loss import BaseLoss | ||
from . import OPENOCC_LOSS | ||
|
||
|
||
@OPENOCC_LOSS.register_module() | ||
class EikonalLoss(BaseLoss): | ||
|
||
def __init__(self, weight=1.0, input_dict=None, **kwargs): | ||
super().__init__(weight) | ||
|
||
if input_dict is None: | ||
self.input_dict = { | ||
'eik_grad': 'eik_grad', | ||
} | ||
else: | ||
self.input_dict = input_dict | ||
self.loss_func = self.eikonal_loss | ||
|
||
def eikonal_loss(self, eik_grad): | ||
grad_theta = eik_grad | ||
eikonal_loss = ((grad_theta.norm(2, dim=-1) - 1) ** 2).mean() | ||
return eikonal_loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import torch.nn as nn | ||
from . import OPENOCC_LOSS | ||
from utils.tb_wrapper import WrappedTBWriter | ||
if 'selfocc' in WrappedTBWriter._instance_dict: | ||
writer = WrappedTBWriter.get_instance('selfocc') | ||
else: | ||
writer = None | ||
|
||
@OPENOCC_LOSS.register_module() | ||
class MultiLoss(nn.Module): | ||
|
||
def __init__(self, loss_cfgs): | ||
super().__init__() | ||
|
||
assert isinstance(loss_cfgs, list) | ||
self.num_losses = len(loss_cfgs) | ||
|
||
losses = [] | ||
for loss_cfg in loss_cfgs: | ||
losses.append(OPENOCC_LOSS.build(loss_cfg)) | ||
self.losses = nn.ModuleList(losses) | ||
self.iter_counter = 0 | ||
|
||
def forward(self, inputs): | ||
|
||
loss_dict = {} | ||
tot_loss = 0. | ||
for loss_func in self.losses: | ||
loss = loss_func(inputs) | ||
tot_loss += loss | ||
loss_dict.update({ | ||
loss_func.__class__.__name__: \ | ||
loss.detach().item() | ||
}) | ||
if writer and self.iter_counter % 10 == 0: | ||
writer.add_scalar( | ||
f'loss/{loss_func.__class__.__name__}', | ||
loss.detach().item(), self.iter_counter) | ||
if writer and self.iter_counter % 10 == 0: | ||
writer.add_scalar( | ||
'loss/total', tot_loss.detach().item(), self.iter_counter) | ||
self.iter_counter += 1 | ||
|
||
return tot_loss, loss_dict |
Oops, something went wrong.