From 649d3fb22b5d677cf916501da4d474bcfdcc0b50 Mon Sep 17 00:00:00 2001
From: Luis Pineda
Date: Tue, 14 Jun 2022 13:17:24 -0400
Subject: [PATCH] Moved the method that retracts all variables with a given
 delta to Objective (#214)

* Moved the method that retracts all variables with a given delta to Objective.

* Renamed step_optim_vars as retract_optim_vars.
---
 theseus/core/objective.py                    | 17 ++++++++++++
 theseus/optimizer/linear/linear_optimizer.py | 13 +++-------
 .../nonlinear/nonlinear_optimizer.py         | 26 ++++---------------
 theseus/theseus_layer.py                     |  4 +--
 4 files changed, 27 insertions(+), 33 deletions(-)

diff --git a/theseus/core/objective.py b/theseus/core/objective.py
index b92e93313..b387ac7c8 100644
--- a/theseus/core/objective.py
+++ b/theseus/core/objective.py
@@ -503,3 +503,20 @@ def to(self, *args, **kwargs):
         self.dtype = dtype or self.dtype
         if self._vectorization_to is not None:
             self._vectorization_to(*args, **kwargs)
+
+    def retract_optim_vars(
+        self,
+        delta: torch.Tensor,
+        ordering: Iterable[Manifold],
+        ignore_mask: Optional[torch.Tensor] = None,
+        force_update: bool = False,
+    ):
+        var_idx = 0
+        for var in ordering:
+            new_var = var.retract(delta[:, var_idx : var_idx + var.dof()])
+            if ignore_mask is None or force_update:
+                var.update(new_var.data)
+            else:
+                var.update(new_var.data, batch_ignore_mask=ignore_mask)
+            var_idx += var.dof()
+        self.update_vectorization()
diff --git a/theseus/optimizer/linear/linear_optimizer.py b/theseus/optimizer/linear/linear_optimizer.py
index 0e75164ac..8fc239a29 100644
--- a/theseus/optimizer/linear/linear_optimizer.py
+++ b/theseus/optimizer/linear/linear_optimizer.py
@@ -69,17 +69,10 @@ def _optimize_impl(
                 warnings.warn(msg, RuntimeWarning)
                 info.status[:] = LinearOptimizerStatus.FAIL
                 return info
-        self.retract_and_update_variables(delta)
+        self.objective.retract_optim_vars(
+            delta, self.linear_solver.linearization.ordering
+        )
         info.status[:] = LinearOptimizerStatus.CONVERGED
         for var in self.linear_solver.linearization.ordering:
             info.best_solution[var.name] = var.data.clone().cpu()
         return info
-
-    # retracts all variables in the given order and updates their values
-    # with the result
-    def retract_and_update_variables(self, delta: torch.Tensor):
-        var_idx = 0
-        for var in self.linear_solver.linearization.ordering:
-            new_var = var.retract(delta[:, var_idx : var_idx + var.dof()])
-            var.update(new_var.data)
-            var_idx += var.dof()
diff --git a/theseus/optimizer/nonlinear/nonlinear_optimizer.py b/theseus/optimizer/nonlinear/nonlinear_optimizer.py
index 3ab1539d1..53859d6ae 100644
--- a/theseus/optimizer/nonlinear/nonlinear_optimizer.py
+++ b/theseus/optimizer/nonlinear/nonlinear_optimizer.py
@@ -262,8 +262,11 @@ def _optimize_loop(
                 step_size = self.params.step_size
                 force_update = False

-            self.retract_and_update_variables(
-                delta, converged_indices, step_size, force_update=force_update
+            self.objective.retract_optim_vars(
+                delta * step_size,
+                self.linear_solver.linearization.ordering,
+                ignore_mask=converged_indices,
+                force_update=force_update,
             )

             # check for convergence
@@ -376,22 +379,3 @@ def _optimize_impl(
     @abc.abstractmethod
     def compute_delta(self, **kwargs) -> torch.Tensor:
         pass
-
-    # retracts all variables in the given order and updates their values
-    # with the result
-    def retract_and_update_variables(
-        self,
-        delta: torch.Tensor,
-        converged_indices: torch.Tensor,
-        step_size: float,
-        force_update: bool = False,
-    ):
-        var_idx = 0
-        delta = step_size * delta
-        for var in self.linear_solver.linearization.ordering:
-            new_var = var.retract(delta[:, var_idx : var_idx + var.dof()])
-            if force_update:
-                var.update(new_var.data)
-            else:
-                var.update(new_var.data, batch_ignore_mask=converged_indices)
-            var_idx += var.dof()
diff --git a/theseus/theseus_layer.py b/theseus/theseus_layer.py
index 0b0d3722d..b58188318 100644
--- a/theseus/theseus_layer.py
+++ b/theseus/theseus_layer.py
@@ -218,8 +218,8 @@ def backward(ctx, *grad_outputs):
         with torch.no_grad():
             bwd_optimizer.linear_solver.linearization.linearize()
             delta = bwd_optimizer.linear_solver.solve()
-            bwd_optimizer.retract_and_update_variables(
-                delta, None, 1.0, force_update=True
+            bwd_optimizer.objective.retract_optim_vars(
+                delta, bwd_optimizer.linear_solver.linearization.ordering
             )

         # Compute gradients.
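
For reference, a minimal usage sketch of the new Objective.retract_optim_vars call pattern follows. It is not part of the patch: the variable construction (th.Vector(data=..., name=...)), the standalone ordering list, and the tensors below are assumptions made for illustration. Inside the optimizers the ordering comes from linear_solver.linearization.ordering and delta is produced by the linear solver; callers now pre-scale delta by the step size instead of passing step_size separately.

    import torch
    import theseus as th  # assumes a theseus build that includes this patch

    # Two 2-dof vector variables with batch size 2 (assumed constructor usage).
    x = th.Vector(data=torch.zeros(2, 2), name="x")
    y = th.Vector(data=torch.zeros(2, 2), name="y")
    ordering = [x, y]  # stand-in for linear_solver.linearization.ordering

    objective = th.Objective()  # hosts the (now objective-level) retraction

    # delta concatenates per-variable tangent updates along dim 1: width 2 + 2.
    step_size = 0.5
    delta = torch.ones(2, 4)

    # Batch entries marked True are treated as already converged and are left
    # untouched unless force_update=True.
    converged = torch.tensor([True, False])

    objective.retract_optim_vars(
        delta * step_size,
        ordering,
        ignore_mask=converged,
        force_update=False,
    )
    print(x.data)  # row 0 unchanged, row 1 retracted by +0.5 in each dof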