
Update robust loss #502

Closed · wants to merge 7 commits
77 changes: 60 additions & 17 deletions theseus/core/robust_cost_function.py
@@ -52,6 +52,7 @@ def __init__(
cost_function: CostFunction,
loss_cls: Type[RobustLoss],
log_loss_radius: Variable,
RobustSum: bool = False,
Contributor:

Let's keep Python naming conventions and use lower_case for variable names instead of RobustSum. Not sure what a good name for this should be; let's start with flatten_dims: bool = False, where False defaults to the old behavior and True enables the new version.

name: Optional[str] = None,
):
self.cost_function = cost_function
@@ -70,6 +71,7 @@ def __init__(
self.log_loss_radius = log_loss_radius
self.register_aux_var("log_loss_radius")
self.loss = loss_cls()
self.RobustSum = RobustSum
Contributor:

self.flatten_dims = flatten_dims


def error(self) -> torch.Tensor:
warnings.warn(
@@ -80,20 +82,33 @@ def error(self) -> torch.Tensor:

def weighted_error(self) -> torch.Tensor:
weighted_error = self.cost_function.weighted_error()
squared_norm = torch.sum(weighted_error**2, dim=1, keepdim=True)
error_loss = self.loss.evaluate(squared_norm, self.log_loss_radius.tensor)


if self.RobustSum:
Contributor:

I think this can be done much more simply. If I'm not mistaken, you can keep all of the old code and just add the following lines at the beginning of weighted_error()

if self.flatten_dims:
    weighted_error = weighted_error.view(-1, 1)

and then reshape the output back to (-1, self.dim()) before returning.

A similar procedure can be followed for weighted_jacobians.
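
For illustration, a minimal sketch of what the suggested approach could look like inside weighted_error(); the flatten_dims flag, the log-radius expansion, and the final reshape follow the comment above and are assumptions, not necessarily the exact code from #503:

def weighted_error(self) -> torch.Tensor:
    weighted_error = self.cost_function.weighted_error()
    log_radius = self.log_loss_radius.tensor
    if self.flatten_dims:
        batch_size = weighted_error.shape[0]
        # Treat every entry of the weighted error as its own 1-d residual and
        # repeat the per-cost radius so it lines up with the flattened rows.
        weighted_error = weighted_error.reshape(-1, 1)
        log_radius = log_radius.expand(batch_size, self.dim()).reshape(-1, 1)
    squared_norm = torch.sum(weighted_error**2, dim=1, keepdim=True)
    error_loss = self.loss.evaluate(squared_norm, log_radius)
    error = (
        torch.ones_like(weighted_error)
        * (error_loss / weighted_error.shape[1] + RobustCostFunction._EPS).sqrt()
    )
    # Restore the (batch_size, dim) layout expected by the objective.
    return error.reshape(-1, self.dim()) if self.flatten_dims else error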

Author:

I am not sure what you mean, or why I need to reshape it. What I want is that, given an error of any shape, e.g. (batch, dim), and the corresponding radius of shape (batch, 1), both are passed to the evaluation function self.loss.evaluate, which returns the scaled loss.

[screenshot omitted]
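
For reference, a small stand-alone check of the broadcasting behavior described here; HuberLoss is used purely as an example subclass and the shapes are made up:

import torch
from theseus.core.robust_loss import HuberLoss

batch_size, dim = 4, 3
weighted_error = torch.randn(batch_size, dim)  # error of shape (batch, dim)
log_radius = torch.zeros(batch_size, 1)        # one radius per cost, shape (batch, 1)

# evaluate() broadcasts the (batch, 1) radius against the (batch, dim) input
# and returns one scaled loss value per error element.
per_element_loss = HuberLoss.evaluate(weighted_error**2, log_radius)
print(per_element_loss.shape)                  # torch.Size([4, 3])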

Contributor:

It's a bit hard to explain, but it works; I just tested it. I'll put up another PR soon.

Contributor:

Please take a look at #503. The code for the jacobians part was a bit trickier than I suggested above, but I was still able to mostly reuse the old computation. Also look at the unit test I added and let me know if it covers the expected behavior. Thanks!
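
For illustration, a rough sketch of how the flattened path might rescale the jacobians while still reusing the old per-cost computation; this is an assumption based on the discussion above, not necessarily the code in #503:

def weighted_jacobians_error(self) -> Tuple[List[torch.Tensor], torch.Tensor]:
    jacobians, weighted_error = self.cost_function.weighted_jacobians_error()
    batch_size = weighted_error.shape[0]
    log_radius = self.log_loss_radius.tensor
    if self.flatten_dims:
        weighted_error = weighted_error.reshape(-1, 1)
        log_radius = log_radius.expand(batch_size, self.dim()).reshape(-1, 1)
    squared_norm = torch.sum(weighted_error**2, dim=1, keepdim=True)
    rescale = (
        self.loss.linearize(squared_norm, log_radius) + RobustCostFunction._EPS
    ).sqrt()
    if self.flatten_dims:
        # One rescale entry per error element; broadcast it over each
        # jacobian's last (variable-dof) dimension.
        rescale_jac = rescale.view(batch_size, self.dim(), 1)
        return [rescale_jac * jac for jac in jacobians], (
            rescale * weighted_error
        ).view(batch_size, self.dim())
    return [rescale.view(-1, 1, 1) * jac for jac in jacobians], rescale * weighted_error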

Author:

> It's a bit hard to explain, but it works; I just tested it. I'll put up another PR soon.

Oh! I see what you mean! You treat every element of the original cost as the original "squared norm".

squared_norm = torch.sum(weighted_error**2, dim=1, keepdim=True)
error_loss = self.loss.evaluate(squared_norm, self.log_loss_radius.tensor)
return (
torch.ones_like(weighted_error)
* (error_loss / self.dim() + RobustCostFunction._EPS).sqrt()
)
else:
squared_norm = weighted_error**2
error_loss = self.loss.evaluate(squared_norm, self.log_loss_radius.tensor)
return (
(error_loss + RobustCostFunction._EPS).sqrt()
)
# Inside will compare squared_norm with exp(radius)
Contributor:

Please remove any debug comments.


#print(weighted_error, error_loss)
#print("scaled", error_loss)
#print(self.dim())
# The return value is a hacky way to make it so that
# ||weighted_error||^2 = error_loss
# By doing this we avoid having to change the objective's error computation
# specifically for robust cost functions. The issue for this type of cost
# function is that the theory requires us to maintain scaled errors/jacobians
# of dim = robust_fn.cost_function.dim() to do the linearization properly,
# but the actual error has dim = 1, being the result of loss(||error||^2).
return (
torch.ones_like(weighted_error)
* (error_loss / self.dim() + RobustCostFunction._EPS).sqrt()
)


def jacobians(self) -> Tuple[List[torch.Tensor], torch.Tensor]:
warnings.warn(
@@ -107,15 +122,43 @@ def weighted_jacobians_error(self) -> Tuple[List[torch.Tensor], torch.Tensor]:
weighted_jacobians,
weighted_error,
) = self.cost_function.weighted_jacobians_error()
squared_norm = torch.sum(weighted_error**2, dim=1, keepdim=True)
rescale = (
self.loss.linearize(squared_norm, self.log_loss_radius.tensor)
+ RobustCostFunction._EPS
).sqrt()

return [
rescale.view(-1, 1, 1) * jacobian for jacobian in weighted_jacobians
], rescale * weighted_error
#squared_norm = torch.sum(weighted_error**2, dim=1, keepdim=True)
Contributor:

See comment above on a suggestion for implementing this more easily.

#print(len(weighted_jacobians), weighted_error.shape, weighted_error.dim(), self.cost_function.dim())

# I did not check the linearization part; I assume it is correct

#print(rescale.shape)
# The rescale should be reshaped to match the jacobian shape so that the multiplication works and the chain rule is respected


if self.RobustSum:

squared_norm = torch.sum(weighted_error**2, dim=1, keepdim=True)
rescale = (
self.loss.linearize(squared_norm, self.log_loss_radius.tensor)
+ RobustCostFunction._EPS
).sqrt()

return [
rescale.view(-1, 1, 1) * jacobian for jacobian in weighted_jacobians
], rescale * weighted_error

else:
squared_norm = weighted_error**2
rescale = (
self.loss.linearize(squared_norm, self.log_loss_radius.tensor)
+ RobustCostFunction._EPS
).sqrt()

# Broadcasting over the last (dof) dimension handles the per-jacobian tiling.
return [
rescale.unsqueeze(2) * jacobian for jacobian in weighted_jacobians
], rescale * weighted_error



def dim(self) -> int:
return self.cost_function.dim()
@@ -143,4 +186,4 @@ def _supports_masking(self) -> bool:
@_supports_masking.setter
def _supports_masking(self, val: bool):
self.cost_function._supports_masking = val
self.__supports_masking__ = val
self.__supports_masking__ = val
Contributor:

Missing newline at end of file.

9 changes: 5 additions & 4 deletions theseus/core/robust_loss.py
@@ -11,13 +11,14 @@


class RobustLoss(abc.ABC):
# .square() makes it easier to compare with residual.square(), which is x
@classmethod
def evaluate(cls, x: torch.Tensor, log_radius: torch.Tensor) -> torch.Tensor:
return cls._evaluate_impl(x, log_radius.exp())

return cls._evaluate_impl(x, log_radius.square())
Contributor:

We should keep this as exp().

Author:

Sure. Could I know the motivation behind it?

Contributor:

Since square() is not monotonic, I personally prefer exp(), which is more useful when we try to learn a robust loss radius.

Author:

Oh, that's great! Early on I tried to let the system learn the radius automatically and it did not work well at the time (the performance was poor), but with the new version the training outcome may change.
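
As a tiny illustration of the point about monotonicity (the values below are arbitrary): exp() maps every value of the learned parameter to a distinct positive radius, while square() collapses +r and -r to the same radius.

import torch

# An unconstrained learnable parameter; the loss maps it to a radius.
log_radius = torch.tensor([-1.0, 0.0, 1.0])
print(log_radius.exp())     # tensor([0.3679, 1.0000, 2.7183]): positive and strictly increasing
print(log_radius.square())  # tensor([1., 0., 1.]): -1 and 1 map to the same radius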

@classmethod
def linearize(cls, x: torch.Tensor, log_radius: torch.Tensor) -> torch.Tensor:
return cls._linearize_impl(x, log_radius.exp())
return cls._linearize_impl(x, log_radius.square())

@staticmethod
@abc.abstractmethod
@@ -49,4 +50,4 @@ def _evaluate_impl(x: torch.Tensor, radius: torch.Tensor) -> torch.Tensor:

@staticmethod
def _linearize_impl(x: torch.Tensor, radius: torch.Tensor) -> torch.Tensor:
return torch.sqrt(radius / torch.max(x, radius) + _LOSS_EPS)
return torch.sqrt(radius / torch.max(x, radius) + _LOSS_EPS)
Contributor:

Missing newline at end of file.
