Fixed bug in DLM perturbation jacobians. (#247)
luisenp authored Jul 13, 2022
1 parent 9735a64 commit d107abb
Showing 1 changed file with 13 additions and 9 deletions.
theseus/theseus_layer.py: 13 additions & 9 deletions
@@ -47,15 +47,15 @@ def forward(
         )
         optimizer_kwargs = optimizer_kwargs or {}
         backward_mode = optimizer_kwargs.get("backward_mode", None)
-        dlm_epsilon = optimizer_kwargs.get(
-            TheseusLayerDLMForward._DLM_EPSILON_STR, 1e-2
-        )
-        if not isinstance(dlm_epsilon, float):
-            raise ValueError(
-                f"{TheseusLayerDLMForward._DLM_EPSILON_STR} must be a float "
-                f"but {type(dlm_epsilon)} was given."
-            )
         if backward_mode == BackwardMode.DLM:
+            dlm_epsilon = optimizer_kwargs.get(
+                TheseusLayerDLMForward._DLM_EPSILON_STR, 1e-2
+            )
+            if not isinstance(dlm_epsilon, float):
+                raise ValueError(
+                    f"{TheseusLayerDLMForward._DLM_EPSILON_STR} must be a float "
+                    f"but {type(dlm_epsilon)} was given."
+                )
 
             if self._dlm_bwd_objective is None:
                 _obj, _opt = _instantiate_dlm_bwd_objective(self.objective)
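This first hunk moves the parsing and type-checking of the DLM epsilon kwarg inside the BackwardMode.DLM branch, so other backward modes never read it. The sketch below shows how the kwarg is passed from user code; the one-variable objective is invented for illustration, and it assumes th.BackwardMode, th.AutoDiffCostFunction, and th.GaussNewton are the usual theseus exports and that TheseusLayerDLMForward._DLM_EPSILON_STR resolves to the string "dlm_epsilon":

import torch
import theseus as th

# Hypothetical one-variable objective, invented for illustration only.
v = th.Vector(1, name="v")

def err_fn(optim_vars, aux_vars):
    return optim_vars[0].tensor  # residual pulls v toward zero

objective = th.Objective()
objective.add(th.AutoDiffCostFunction([v], err_fn, 1, name="v_to_zero"))
layer = th.TheseusLayer(th.GaussNewton(objective, max_iterations=10))

# With this commit, "dlm_epsilon" is read and validated only when the DLM
# backward mode is selected. It must be a Python float; an int such as 1
# would raise the ValueError shown in the diff.
solution, info = layer.forward(
    input_tensors={"v": torch.ones(1, 1)},
    optimizer_kwargs={
        "backward_mode": th.BackwardMode.DLM,
        "dlm_epsilon": 1e-3,
    },
)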
@@ -283,7 +283,11 @@ def error(self) -> torch.Tensor:
 
     def jacobians(self) -> Tuple[List[torch.Tensor], torch.Tensor]:
         d = self.dim()
-        aux = torch.eye(d).unsqueeze(0).expand(self.var.shape[0], d, d)
+        aux = (
+            torch.eye(d, dtype=self.epsilon.dtype, device=self.epsilon.device)
+            .unsqueeze(0)
+            .expand(self.var.shape[0], d, d)
+        )
         euclidean_grad_flat = self.epsilon.tensor.view(-1, 1, 1) * aux
         euclidean_grad = euclidean_grad_flat.unflatten(2, self.var.shape[1:])
         return [self.var.project(euclidean_grad, is_sparse=True)], self.error()
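This second hunk is the fix named in the commit title: the identity matrix used to assemble the perturbation's Euclidean gradient is now created with the dtype and device of self.epsilon rather than PyTorch's defaults (float32 on CPU). A standalone sketch of the failure mode, with hypothetical names and shapes standing in for self.epsilon.tensor and the batch dimension of self.var:

import torch

batch, d = 4, 3
# Stand-in for self.epsilon.tensor, here in double precision.
epsilon = torch.full((batch,), 1e-2, dtype=torch.float64)

# Old pattern: torch.eye uses the default dtype (float32) and CPU device.
aux_old = torch.eye(d).unsqueeze(0).expand(batch, d, d)

# Fixed pattern from the diff: identity built in epsilon's dtype/device.
aux_new = (
    torch.eye(d, dtype=epsilon.dtype, device=epsilon.device)
    .unsqueeze(0)
    .expand(batch, d, d)
)

grad_old = epsilon.view(-1, 1, 1) * aux_old  # silently promotes to float64
grad_new = epsilon.view(-1, 1, 1) * aux_new  # stays float64 throughout
assert grad_new.dtype == torch.float64

On CPU the old pattern only causes a silent dtype promotion, but with epsilon on a CUDA device the product with a CPU-resident identity raises a device-mismatch RuntimeError; building aux directly in epsilon's dtype and device avoids both.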
