Skip to content

Commit

Permalink
fix bugs of training.
Browse files Browse the repository at this point in the history
  • Loading branch information
guangkaixu committed Jul 20, 2022
1 parent c5370f1 commit 2d4bf66
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 17 deletions.
3 changes: 1 addition & 2 deletions LeReS/Train/lib/models/ILNR_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def transform(self, gt):
gt_i = gt[i]
mask = gt_i > 0
depth_valid = gt_i[mask]
depth_valid = depth_valid[:5]
if depth_valid.shape[0] < 10:
data_mean.append(torch.tensor(0).cuda())
data_std_dev.append(torch.tensor(1).cuda())
Expand Down Expand Up @@ -49,7 +48,7 @@ def forward(self, pred, gt):
pred_maskbatch = pred[mask_batch]
gt_maskbatch = gt[mask_batch]

gt_mean, gt_std = self.transform(gt)
gt_mean, gt_std = self.transform(gt_maskbatch)
gt_trans = (gt_maskbatch - gt_mean[:, None, None, None]) / (gt_std[:, None, None, None] + 1e-8)

B, C, H, W = gt_maskbatch.shape
Expand Down
8 changes: 4 additions & 4 deletions LeReS/Train/lib/models/MSGIL_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def transform(self, gt):
gt_i = gt[i]
mask = gt_i > 0
depth_valid = gt_i[mask]
depth_valid = depth_valid[:5]
if depth_valid.shape[0] < 10:
data_mean.append(torch.tensor(0).cuda())
data_std_dev.append(torch.tensor(1).cuda())
Expand All @@ -63,9 +62,10 @@ def forward(self, pred, gt):
gt_mean, gt_std = self.transform(gt)
gt_trans = (gt - gt_mean[:, None, None, None]) / (gt_std[:, None, None, None] + 1e-8)
for i in range(self.scales_num):
d_gt = gt_trans[:, :, ::2, ::2]
d_pred = pred[:, :, ::2, ::2]
d_mask = mask[:, :, ::2, ::2]
step = pow(2, i)
d_gt = gt_trans[:, :, ::step, ::step]
d_pred = pred[:, :, ::step, ::step]
d_mask = mask[:, :, ::step, ::step]
grad_term += self.one_scale_gradient_loss(d_pred, d_gt, d_mask)
return grad_term

Expand Down
5 changes: 4 additions & 1 deletion LeReS/Train/lib/models/PWN_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,10 @@ def forward(self, pred_depths, gt_depths, images, focal_length):
random_input_cos = torch.abs(torch.sum(random_inputs_A * random_inputs_B, dim=0))
loss += torch.sum(torch.abs(random_target_cos - random_input_cos)) / (random_target_cos.shape[0] + 1e-8)

return loss[0].float()/n
if loss[0] != 0:
return loss[0].float() / n
else:
return pred_depths.sum() * 0.0



Expand Down
20 changes: 10 additions & 10 deletions LeReS/Train/lib/models/multi_depth_model_auxiv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,8 @@ def inference(self, data):
out = self.forward(data, is_train=False)
pred_depth = out['decoder']
pred_disp = out['auxi']
pred_depth_normalize = (pred_depth - pred_depth.min() + 1) / (pred_depth.max() - pred_depth.min()) #pred_depth - pred_depth.min() #- pred_depth.max()
pred_depth_out = pred_depth
pred_disp_normalize = (pred_disp - pred_disp.min() + 1) / (pred_disp.max() - pred_disp.min())
return {'pred_depth': pred_depth_out, 'pred_depth_normalize': pred_depth_normalize,
'pred_disp': pred_disp, 'pred_disp_normalize': pred_disp_normalize,
}
return {'pred_depth': pred_depth_out, 'pred_disp': pred_disp}


class ModelLoss(nn.Module):
Expand Down Expand Up @@ -83,16 +79,20 @@ def auxi_loss(self, auxi, data):
if 'disp' not in data:
return {'total_loss': torch.tensor(0.0).cuda()}

gt = data['disp'].to(device=auxi.device)
gt_disp = data['disp'].to(device=auxi.device)

mask_mid_quality = data['quality_flg'] >= 2
gt_disp_mid = gt_disp[mask_mid_quality]
auxi_mid = auxi[mask_mid_quality]

if '_ranking-edge-auxi_' in cfg.TRAIN.LOSS_MODE.lower():
loss['ranking-edge_auxiloss'] = self.ranking_edge_auxiloss(auxi, gt, data['rgb'])
loss['ranking-edge_auxiloss'] = self.ranking_edge_auxiloss(auxi, gt_disp, data['rgb'])

if '_msgil-normal-auxi_' in cfg.TRAIN.LOSS_MODE.lower():
loss['msg_normal_auxiloss'] = (self.msg_normal_auxiloss(auxi, gt) * 0.5).float()
loss['msg_normal_auxiloss'] = (self.msg_normal_auxiloss(auxi_mid, gt_disp_mid) * 0.5).float()

if '_meanstd-tanh-auxi_' in cfg.TRAIN.LOSS_MODE.lower():
loss['meanstd-tanh_auxiloss'] = self.meanstd_tanh_auxiloss(auxi, gt)
loss['meanstd-tanh_auxiloss'] = self.meanstd_tanh_auxiloss(auxi_mid, gt_disp_mid)

total_loss = sum(loss.values())
loss['total_loss'] = total_loss * cfg.TRAIN.LOSS_AUXI_WEIGHT
Expand Down Expand Up @@ -157,7 +157,7 @@ def decoder_loss(self, pred_logit, data):

# Multi-scale Gradient Loss
if '_msgil-normal_' in cfg.TRAIN.LOSS_MODE.lower():
loss['msg_normal_loss'] = (self.msg_normal_loss(pred_depth, gt_depth) * 0.1).float()
loss['msg_normal_loss'] = (self.msg_normal_loss(pred_depth_mid, gt_depth_mid) * 0.5).float()

total_loss = sum(loss.values())
loss['total_loss'] = total_loss
Expand Down

0 comments on commit 2d4bf66

Please sign in to comment.