diff --git a/LeReS/Train/lib/models/ILNR_loss.py b/LeReS/Train/lib/models/ILNR_loss.py index d4e5ff6..ebfa12e 100644 --- a/LeReS/Train/lib/models/ILNR_loss.py +++ b/LeReS/Train/lib/models/ILNR_loss.py @@ -20,7 +20,6 @@ def transform(self, gt): gt_i = gt[i] mask = gt_i > 0 depth_valid = gt_i[mask] - depth_valid = depth_valid[:5] if depth_valid.shape[0] < 10: data_mean.append(torch.tensor(0).cuda()) data_std_dev.append(torch.tensor(1).cuda()) @@ -49,7 +48,7 @@ def forward(self, pred, gt): pred_maskbatch = pred[mask_batch] gt_maskbatch = gt[mask_batch] - gt_mean, gt_std = self.transform(gt) + gt_mean, gt_std = self.transform(gt_maskbatch) gt_trans = (gt_maskbatch - gt_mean[:, None, None, None]) / (gt_std[:, None, None, None] + 1e-8) B, C, H, W = gt_maskbatch.shape diff --git a/LeReS/Train/lib/models/MSGIL_loss.py b/LeReS/Train/lib/models/MSGIL_loss.py index 9cc7395..53d747b 100644 --- a/LeReS/Train/lib/models/MSGIL_loss.py +++ b/LeReS/Train/lib/models/MSGIL_loss.py @@ -42,7 +42,6 @@ def transform(self, gt): gt_i = gt[i] mask = gt_i > 0 depth_valid = gt_i[mask] - depth_valid = depth_valid[:5] if depth_valid.shape[0] < 10: data_mean.append(torch.tensor(0).cuda()) data_std_dev.append(torch.tensor(1).cuda()) @@ -63,9 +62,10 @@ def forward(self, pred, gt): gt_mean, gt_std = self.transform(gt) gt_trans = (gt - gt_mean[:, None, None, None]) / (gt_std[:, None, None, None] + 1e-8) for i in range(self.scales_num): - d_gt = gt_trans[:, :, ::2, ::2] - d_pred = pred[:, :, ::2, ::2] - d_mask = mask[:, :, ::2, ::2] + step = pow(2, i) + d_gt = gt_trans[:, :, ::step, ::step] + d_pred = pred[:, :, ::step, ::step] + d_mask = mask[:, :, ::step, ::step] grad_term += self.one_scale_gradient_loss(d_pred, d_gt, d_mask) return grad_term diff --git a/LeReS/Train/lib/models/PWN_edges.py b/LeReS/Train/lib/models/PWN_edges.py index 93f30f1..a74eca8 100644 --- a/LeReS/Train/lib/models/PWN_edges.py +++ b/LeReS/Train/lib/models/PWN_edges.py @@ -241,7 +241,10 @@ def forward(self, pred_depths, gt_depths, images, focal_length): random_input_cos = torch.abs(torch.sum(random_inputs_A * random_inputs_B, dim=0)) loss += torch.sum(torch.abs(random_target_cos - random_input_cos)) / (random_target_cos.shape[0] + 1e-8) - return loss[0].float()/n + if loss[0] != 0: + return loss[0].float() / n + else: + return pred_depths.sum() * 0.0 diff --git a/LeReS/Train/lib/models/multi_depth_model_auxiv2.py b/LeReS/Train/lib/models/multi_depth_model_auxiv2.py index 37bb251..f7c7b4b 100644 --- a/LeReS/Train/lib/models/multi_depth_model_auxiv2.py +++ b/LeReS/Train/lib/models/multi_depth_model_auxiv2.py @@ -32,12 +32,8 @@ def inference(self, data): out = self.forward(data, is_train=False) pred_depth = out['decoder'] pred_disp = out['auxi'] - pred_depth_normalize = (pred_depth - pred_depth.min() + 1) / (pred_depth.max() - pred_depth.min()) #pred_depth - pred_depth.min() #- pred_depth.max() pred_depth_out = pred_depth - pred_disp_normalize = (pred_disp - pred_disp.min() + 1) / (pred_disp.max() - pred_disp.min()) - return {'pred_depth': pred_depth_out, 'pred_depth_normalize': pred_depth_normalize, - 'pred_disp': pred_disp, 'pred_disp_normalize': pred_disp_normalize, - } + return {'pred_depth': pred_depth_out, 'pred_disp': pred_disp} class ModelLoss(nn.Module): @@ -83,16 +79,20 @@ def auxi_loss(self, auxi, data): if 'disp' not in data: return {'total_loss': torch.tensor(0.0).cuda()} - gt = data['disp'].to(device=auxi.device) + gt_disp = data['disp'].to(device=auxi.device) + + mask_mid_quality = data['quality_flg'] >= 2 + gt_disp_mid = gt_disp[mask_mid_quality] + auxi_mid = auxi[mask_mid_quality] if '_ranking-edge-auxi_' in cfg.TRAIN.LOSS_MODE.lower(): - loss['ranking-edge_auxiloss'] = self.ranking_edge_auxiloss(auxi, gt, data['rgb']) + loss['ranking-edge_auxiloss'] = self.ranking_edge_auxiloss(auxi, gt_disp, data['rgb']) if '_msgil-normal-auxi_' in cfg.TRAIN.LOSS_MODE.lower(): - loss['msg_normal_auxiloss'] = (self.msg_normal_auxiloss(auxi, gt) * 0.5).float() + loss['msg_normal_auxiloss'] = (self.msg_normal_auxiloss(auxi_mid, gt_disp_mid) * 0.5).float() if '_meanstd-tanh-auxi_' in cfg.TRAIN.LOSS_MODE.lower(): - loss['meanstd-tanh_auxiloss'] = self.meanstd_tanh_auxiloss(auxi, gt) + loss['meanstd-tanh_auxiloss'] = self.meanstd_tanh_auxiloss(auxi_mid, gt_disp_mid) total_loss = sum(loss.values()) loss['total_loss'] = total_loss * cfg.TRAIN.LOSS_AUXI_WEIGHT @@ -157,7 +157,7 @@ def decoder_loss(self, pred_logit, data): # Multi-scale Gradient Loss if '_msgil-normal_' in cfg.TRAIN.LOSS_MODE.lower(): - loss['msg_normal_loss'] = (self.msg_normal_loss(pred_depth, gt_depth) * 0.1).float() + loss['msg_normal_loss'] = (self.msg_normal_loss(pred_depth_mid, gt_depth_mid) * 0.5).float() total_loss = sum(loss.values()) loss['total_loss'] = total_loss