Update robust loss #502
@@ -52,6 +52,7 @@ def __init__(
        cost_function: CostFunction,
        loss_cls: Type[RobustLoss],
        log_loss_radius: Variable,
        RobustSum: bool = False,
        name: Optional[str] = None,
    ):
        self.cost_function = cost_function

@@ -70,6 +71,7 @@ def __init__(
        self.log_loss_radius = log_loss_radius
        self.register_aux_var("log_loss_radius")
        self.loss = loss_cls()
        self.RobustSum = RobustSum

    def error(self) -> torch.Tensor:
        warnings.warn(
@@ -80,20 +82,33 @@ def error(self) -> torch.Tensor:

    def weighted_error(self) -> torch.Tensor:
        weighted_error = self.cost_function.weighted_error()
        squared_norm = torch.sum(weighted_error**2, dim=1, keepdim=True)
        error_loss = self.loss.evaluate(squared_norm, self.log_loss_radius.tensor)

        if self.RobustSum:
I think this can be done much more simply. If I'm not mistaken, you can keep all of the old code and just add the following lines at the beginning of the method:

    if self.flatten_dims:
        weighted_error = weighted_error.view(-1, 1)

and then reshape the output back to the original shape. A similar procedure can be followed for the jacobians.

It's a bit hard to explain, but it works, I just tested it. I'll put another PR soon.

Please take a look at #503. The code for the jacobians part was a bit trickier than I suggested above, but I was still able to mostly use all of the old computation. Also look at the unit test I added and let me know if it covers the expected behavior. Thanks!

Oh! I see what you mean! You treat every element of the original cost as the original "square norm".
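As a rough sketch of what this flatten-and-reshape approach could look like (illustrative only, not the actual #503 code; it assumes the suggested `flatten_dims` flag, and `weighted_error.shape[1]` stands in for `self.dim()` so both paths share one formula):

```python
def weighted_error(self) -> torch.Tensor:
    weighted_error = self.cost_function.weighted_error()
    if self.flatten_dims:
        # Apply the loss to each element by treating it as its own "squared norm".
        weighted_error = weighted_error.view(-1, 1)
    # Old computation, unchanged except that the division uses the (possibly
    # flattened) second dimension instead of self.dim().
    squared_norm = torch.sum(weighted_error**2, dim=1, keepdim=True)
    error_loss = self.loss.evaluate(squared_norm, self.log_loss_radius.tensor)
    error = torch.ones_like(weighted_error) * (
        error_loss / weighted_error.shape[1] + RobustCostFunction._EPS
    ).sqrt()
    if self.flatten_dims:
        # Reshape the output back to the original (batch_size, dim) shape.
        error = error.view(-1, self.dim())
    return error
```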
            squared_norm = torch.sum(weighted_error**2, dim=1, keepdim=True)
            error_loss = self.loss.evaluate(squared_norm, self.log_loss_radius.tensor)
            return (
                torch.ones_like(weighted_error)
                * (error_loss / self.dim() + RobustCostFunction._EPS).sqrt()
            )
        else:
            squared_norm = weighted_error**2
            error_loss = self.loss.evaluate(squared_norm, self.log_loss_radius.tensor)
            return (
                (error_loss + RobustCostFunction._EPS).sqrt()
            )
            # Inside will compare squared_norm with exp(radius)
Please remove any debug comments.
            #print(weighted_error, error_loss)
            #print("scaled", error_loss)
            #print(self.dim())
        # The return value is a hacky way to make it so that
        # ||weighted_error||^2 = error_loss
        # By doing this we avoid having to change the objective's error computation
        # specifically for robust cost functions. The issue for this type of cost
        # function is that the theory requires us to maintain scaled errors/jacobians
        # of dim = robust_fn.cost_function.dim() to do the linearization properly,
        # but the actual error has dim = 1, being the result of loss(||error||^2).
        return (
            torch.ones_like(weighted_error)
            * (error_loss / self.dim() + RobustCostFunction._EPS).sqrt()
        )
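As a quick sanity check of the identity described in the comment above, here is a small standalone snippet (made-up values, not part of the PR) showing that the squared norm of the returned tensor recovers error_loss up to the epsilon term:

```python
import torch

# Stand-in values: a batch of 2 residuals of dim 3 and a fake robust loss
# output; only the algebraic identity below matters.
weighted_error = torch.randn(2, 3)
error_loss = torch.rand(2, 1)  # pretend output of self.loss.evaluate(...)
eps = 1e-20
dim = weighted_error.shape[1]

returned = torch.ones_like(weighted_error) * (error_loss / dim + eps).sqrt()

# The objective computes ||error||^2 per cost function; for the returned
# tensor this sum equals error_loss (+ dim * eps), as intended.
recovered = torch.sum(returned**2, dim=1, keepdim=True)
print(torch.allclose(recovered, error_loss + dim * eps))  # True
```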

    def jacobians(self) -> Tuple[List[torch.Tensor], torch.Tensor]:
        warnings.warn(
@@ -107,15 +122,43 @@ def weighted_jacobians_error(self) -> Tuple[List[torch.Tensor], torch.Tensor]:
            weighted_jacobians,
            weighted_error,
        ) = self.cost_function.weighted_jacobians_error()
        squared_norm = torch.sum(weighted_error**2, dim=1, keepdim=True)
        rescale = (
            self.loss.linearize(squared_norm, self.log_loss_radius.tensor)
            + RobustCostFunction._EPS
        ).sqrt()

        return [
            rescale.view(-1, 1, 1) * jacobian for jacobian in weighted_jacobians
        ], rescale * weighted_error
        #squared_norm = torch.sum(weighted_error**2, dim=1, keepdim=True)
See comment above on a suggestion for implementing this more easily.
        #print(len(weighted_jacobians), weighted_error.shape, weighted_error.dim(), self.cost_function.dim())

        # I do not check the linearization part. I assume it should be correct.

        #print(rescale.shape)
        # rescale should be reshaped to the jacobian's shape so that the
        # multiplication works and fulfills the chain rule
        if self.RobustSum:
            squared_norm = torch.sum(weighted_error**2, dim=1, keepdim=True)
            rescale = (
                self.loss.linearize(squared_norm, self.log_loss_radius.tensor)
                + RobustCostFunction._EPS
            ).sqrt()

            return [
                rescale.view(-1, 1, 1) * jacobian for jacobian in weighted_jacobians
            ], rescale * weighted_error
        else:
            squared_norm = weighted_error**2
            rescale = (
                self.loss.linearize(squared_norm, self.log_loss_radius.tensor)
                + RobustCostFunction._EPS
            ).sqrt()

            for jacobian in weighted_jacobians:
                rescale_tile = rescale.unsqueeze(2)
                rescale_tile = torch.tile(rescale_tile, (1, 1, jacobian.shape[2]))

            return [
                rescale_tile * jacobian for jacobian in weighted_jacobians
            ], rescale * weighted_error
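For reference, a standalone shape sketch of the per-element rescaling in this branch (hypothetical sizes; `rescale` is just a random stand-in for the linearized loss weights):

```python
import torch

batch_size, dim, dof = 2, 3, 4  # hypothetical batch size, cost dim, variable dof
weighted_error = torch.randn(batch_size, dim)
jacobian = torch.randn(batch_size, dim, dof)
rescale = torch.rand(batch_size, dim)  # stand-in for sqrt(loss.linearize(...) + eps)

# Tile rescale along the variable dimension so each residual element scales
# its own jacobian row, which is what the chain rule requires here.
rescale_tile = torch.tile(rescale.unsqueeze(2), (1, 1, jacobian.shape[2]))
scaled_jacobian = rescale_tile * jacobian   # shape (2, 3, 4)
scaled_error = rescale * weighted_error     # shape (2, 3)
print(scaled_jacobian.shape, scaled_error.shape)
```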

    def dim(self) -> int:
        return self.cost_function.dim()
@@ -143,4 +186,4 @@ def _supports_masking(self) -> bool:
    @_supports_masking.setter
    def _supports_masking(self, val: bool):
        self.cost_function._supports_masking = val
        self.__supports_masking__ = val
Missing newline at end of file.
@@ -11,13 +11,14 @@
class RobustLoss(abc.ABC):
    # .square() makes it easier to compare with residual.square(), which is x
    @classmethod
    def evaluate(cls, x: torch.Tensor, log_radius: torch.Tensor) -> torch.Tensor:
        return cls._evaluate_impl(x, log_radius.exp())
        return cls._evaluate_impl(x, log_radius.square())
We should keep this as exp().

Sure. Could I know the rationale behind it?

Since square() is not monotonic, I personally prefer exp(), which is more useful when we try to learn a robust loss radius.

Oh, that's great! Early on I tried to let the system automatically learn the radius and it failed at the time (the performance was not good). But with the new version, the training outcome may change.
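A small illustration of the monotonicity point (a standalone sketch, not part of either file): exp() maps every log-radius value to a distinct positive radius with a gradient that never vanishes, while square() maps +r and -r to the same radius and has zero gradient at 0, which can stall learning of the radius.

```python
import torch

# Hypothetical learnable log-radius parameter.
log_radius = torch.tensor([-1.0, 0.0, 1.0], requires_grad=True)

# exp() is strictly increasing and positive: every log_radius value maps to a
# distinct radius, and d(radius)/d(log_radius) = exp(log_radius) never vanishes.
radius_exp = log_radius.exp()

# square() is not monotonic: -1.0 and 1.0 give the same radius, and the
# gradient 2 * log_radius is exactly zero at log_radius = 0.
radius_sq = log_radius.detach().square()

radius_exp.sum().backward()
print(radius_exp)       # tensor([0.3679, 1.0000, 2.7183], ...)
print(radius_sq)        # tensor([1., 0., 1.])
print(log_radius.grad)  # tensor([0.3679, 1.0000, 2.7183])
```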
    @classmethod
    def linearize(cls, x: torch.Tensor, log_radius: torch.Tensor) -> torch.Tensor:
        return cls._linearize_impl(x, log_radius.exp())
        return cls._linearize_impl(x, log_radius.square())
    @staticmethod
    @abc.abstractmethod

@@ -49,4 +50,4 @@ def _evaluate_impl(x: torch.Tensor, radius: torch.Tensor) -> torch.Tensor:

    @staticmethod
    def _linearize_impl(x: torch.Tensor, radius: torch.Tensor) -> torch.Tensor:
        return torch.sqrt(radius / torch.max(x, radius) + _LOSS_EPS)
Missing newline at end of file.
Let's keep Python naming conventions, using lower_case for variables instead of RobustSum. Not sure what a good name for this should be; let's start with flatten_dims: bool = False, which defaults to the old version, and True means the new version.
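For concreteness, the constructor under the suggested rename might look like this (a sketch only, not the merged code; the final API is settled in #503):

```python
def __init__(
    self,
    cost_function: CostFunction,
    loss_cls: Type[RobustLoss],
    log_loss_radius: Variable,
    flatten_dims: bool = False,  # False keeps the old behavior; True applies the loss per element
    name: Optional[str] = None,
):
    self.cost_function = cost_function
    # ... rest of __init__ as before ...
    self.flatten_dims = flatten_dims
```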