Skip to content

Commit

Permalink
separated out pseudo-ground truth computation from classification loss
Browse files Browse the repository at this point in the history
  • Loading branch information
dakloepfer committed Apr 17, 2024
1 parent f8a0799 commit a6962be
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 15 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ We release implementations of the epipolar regression and the epipolar classific

In principle, these should be able to serve as drop-in replacements for regression and classification losses that use ground-truth correspondences, but depending on the model that is fine-tuned some adjustments to the input / output of the functions may need to be made.

Note: [ASpanFormer](https://aspanformer.github.io), which is one of the models we fine-tune in the paper, also uses a flow loss that requires (coarse) ground-truth correspondences. Similar to the classification loss, we replace those coarse correspondences with the location of the highest-confidence patch that is on the epipolar line as pseudo-ground truth correspondences. For a more convenient adaptation in this and similar cases, we provide a function to compute these pseudo-ground truth correspondences in the `epipolar_losses/epipolar_classification_loss.py` file.

### Scripts

#### Bootstrapping: Estimating Fundamental Matrices
Expand Down
65 changes: 50 additions & 15 deletions epipolar_losses/epipolar_classification_loss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from math import sqrt
from typing import Literal
from typing import Literal, Tuple

import torch
import torch.nn as nn
Expand All @@ -9,14 +9,12 @@
from kornia.utils import create_meshgrid


def epipolar_classification_loss(
def compute_pseudo_gt_corrs(
pred_conf: torch.Tensor,
fundamental_matrix: torch.Tensor,
epipolar_line_threshold: float = sqrt(2.0),
loss_type: Literal["cross_entropy"] = "cross_entropy",
reduction: Literal["mean", "sum", "none"] = "mean",
) -> torch.Tensor:
"""Compute the epipolar classification loss for a batch of predicted match confidences and fundamental matrices.
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute the pseudo-ground truth correspondences by choosing the maximum confidence location along the epipolar line.
Parameters
----------
Expand All @@ -29,17 +27,15 @@ def epipolar_classification_loss(
epipolar_line_threshold (float):
The threshold for the distance of a pixel to the epipolar line to be considered to be part of the epipolar line. By default sqrt(2.0).
loss_type (str):
One of "cross_entropy" or "focal". Describes the type of classification loss to use. By default "cross_entropy".
reduction (str):
One of "mean", "sum", or "none". Describes how to reduce the loss across the batch, use "none" to get the loss for each match.
Returns
-------
classification_loss (tensor):
The epipolar classification loss, either of shape () if reduction is "mean" or "sum", or of shape (M,) if reduction is "none".
pseudo_gt (batch_size x height0 x width0 x height1 x width1 tensor):
The pseudo-ground truth correspondences, where for each pixel (patch) in image 0, the pixel (patch) in image 1 with the maximum confidence along the epipolar line is marked with a 1.
visible_epiline_mask (batch_size x height0 x width0 x 1 x 1 bool tensor):
A mask indicating whether the epipolar line for each pixel (patch) in image 0 is visible in image 1.
"""

b, h0, w0, h1, w1 = pred_conf.shape
dims = {"b": b, "h0": h0, "w0": w0, "h1": h1, "w1": w1}
device = pred_conf.device
Expand Down Expand Up @@ -86,13 +82,52 @@ def epipolar_classification_loss(
)
pseudo_gt = (temp_conf == max_conf_on_epilines).float()

visible_epiline_mask = epilines_mask.any(dim=(-1, -2), keepdim=True)

return pseudo_gt, visible_epiline_mask


def epipolar_classification_loss(
pred_conf: torch.Tensor,
fundamental_matrix: torch.Tensor,
epipolar_line_threshold: float = sqrt(2.0),
loss_type: Literal["cross_entropy"] = "cross_entropy",
reduction: Literal["mean", "sum", "none"] = "mean",
) -> torch.Tensor:
"""Compute the epipolar classification loss for a batch of predicted match confidences and fundamental matrices.
Parameters
----------
pred_conf (batch_size x height0 x width0 x height1 x width1 tensor):
For each pixel (patch) in image 0, this gives a confidence matrix for the probability that the respective pixel (patch) in image 1 is the matching pixel (patch). If any values are outside the range [0, 1], a dual-softmax will be applied (soft nearest-neighbours).
fundamental_matrix (batch_size x 3 x 3 tensor):
The fundamental matrix from 0 to 1 for the respective image pair.
epipolar_line_threshold (float):
The threshold for the distance of a pixel to the epipolar line to be considered to be part of the epipolar line. By default sqrt(2.0).
loss_type (str):
Only "cross_entropy" implemented at the moment. Describes the type of classification loss to use. By default "cross_entropy".
reduction (str):
One of "mean", "sum", or "none". Describes how to reduce the loss across the batch, use "none" to get the loss for each match.
Returns
-------
classification_loss (tensor):
The epipolar classification loss, either of shape () if reduction is "mean" or "sum", or of shape (M,) if reduction is "none".
"""
pseudo_gt, visible_epiline_mask = compute_pseudo_gt_corrs(
pred_conf, fundamental_matrix, epipolar_line_threshold
)

if loss_type == "cross_entropy":
loss = F.binary_cross_entropy(pred_conf, pseudo_gt, reduction="none")
else:
raise NotImplementedError("Unknown loss type: {}".format(loss_type))

# don't compute the loss if the epipolar line does not appear in image1
visible_epiline_mask = epilines_mask.any(dim=(-1, -2), keepdim=True)
loss = loss * visible_epiline_mask

if reduction == "none":
Expand Down

0 comments on commit a6962be

Please sign in to comment.