Skip to content

Commit

Permalink
Add proper accept/reject logic for LM optimizer (#364)
Browse files Browse the repository at this point in the history
* Refactored optimizer retract step so that retract and update happen separately.

* Changed NonlinearOptimizer.step() so that it returns the error.

* Renamed retract_optim_vars to a more descriptive name.

* Changed Objective.update() so that it can also be given an ignore mask.

* Replaced update_optimizer_state with a method called _complete_step, which also returns optional reject indices.

* Made _step private. Added some comments.
  • Loading branch information
luisenp authored Nov 28, 2022
1 parent 40dd795 commit abffd1a
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 57 deletions.
42 changes: 33 additions & 9 deletions theseus/core/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ def __deepcopy__(self, memo):
memo[id(self)] = the_copy
return the_copy

def update(self, input_tensors: Optional[Dict[str, torch.Tensor]] = None):
def _resolve_batch_size(self):
self._batch_size = None

def _get_batch_size(batch_sizes: Sequence[int]) -> int:
Expand All @@ -499,6 +499,19 @@ def _get_batch_size(batch_sizes: Sequence[int]) -> int:
return max_bs
raise ValueError("Provided tensors must be broadcastable.")

batch_sizes = [v.tensor.shape[0] for v in self.optim_vars.values()]
batch_sizes.extend([v.tensor.shape[0] for v in self.aux_vars.values()])
self._batch_size = _get_batch_size(batch_sizes)

# batch_ignore_mask is a boolean mask where batch_ignore_mask[i] = True means
# that, for any variable v, v[i] will *not* be updated. Its shape must be equal
# to the batch size.
def update(
self,
input_tensors: Optional[Dict[str, torch.Tensor]] = None,
batch_ignore_mask: Optional[torch.Tensor] = None,
):

input_tensors = input_tensors or {}
for var_name, tensor in input_tensors.items():
if tensor.ndim < 2:
Expand All @@ -508,11 +521,17 @@ def _get_batch_size(batch_sizes: Sequence[int]) -> int:
f"tensor with name {var_name}."
)
if var_name in self.optim_vars:
self.optim_vars[var_name].update(tensor)
self.optim_vars[var_name].update(
tensor, batch_ignore_mask=batch_ignore_mask
)
elif var_name in self.aux_vars:
self.aux_vars[var_name].update(tensor)
self.aux_vars[var_name].update(
tensor, batch_ignore_mask=batch_ignore_mask
)
elif var_name in self.cost_weight_optim_vars:
self.cost_weight_optim_vars[var_name].update(tensor)
self.cost_weight_optim_vars[var_name].update(
tensor, batch_ignore_mask=batch_ignore_mask
)
warnings.warn(
"Updated a variable declared as optimization, but it is "
"only associated to cost weights and not to any cost functions. "
Expand All @@ -526,9 +545,8 @@ def _get_batch_size(batch_sizes: Sequence[int]) -> int:
)

# Check that the batch size of all tensors is consistent after update
batch_sizes = [v.tensor.shape[0] for v in self.optim_vars.values()]
batch_sizes.extend([v.tensor.shape[0] for v in self.aux_vars.values()])
self._batch_size = _get_batch_size(batch_sizes)
self._resolve_batch_size()
self.update_vectorization_if_needed()

def _vectorization_needs_update(self):
num_updates = {name: v._num_updates for name, v in self._all_variables.items()}
Expand All @@ -545,7 +563,7 @@ def _vectorization_needs_update(self):
def update_vectorization_if_needed(self):
if self.vectorized and self._vectorization_needs_update():
if self._batch_size is None:
self.update()
self._resolve_batch_size()
self._vectorization_run()
self._last_vectorization_has_grad = torch.is_grad_enabled()

Expand Down Expand Up @@ -589,7 +607,13 @@ def _retract_base(
var.update(new_var.tensor, batch_ignore_mask=ignore_mask)
var_idx += var.dof()

def retract_optim_vars(
# Retracts an ordered sequence of variables according to the
# corresponding `delta` tangent vectors.
# This function assumes that delta is constructed as follows:
# For ordering = [v1 v2 ... vn]
# delta = torch.cat([delta_v1, delta_v2, ..., delta_vn], dim=-1)
# where delta_vi.shape = (batch_size, vi.dof)
def retract_vars_sequence(
self,
delta: torch.Tensor,
ordering: Iterable[Manifold],
Expand Down
2 changes: 1 addition & 1 deletion theseus/optimizer/linear/linear_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _optimize_impl(
warnings.warn(msg, RuntimeWarning)
info.status[:] = LinearOptimizerStatus.FAIL
return info
self.objective.retract_optim_vars(
self.objective.retract_vars_sequence(
delta, self.linear_solver.linearization.ordering
)
info.status[:] = LinearOptimizerStatus.CONVERGED
Expand Down
4 changes: 2 additions & 2 deletions theseus/optimizer/nonlinear/gauss_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
rel_err_tolerance: float = 1e-8,
max_iterations: int = 20,
step_size: float = 1.0,
**kwargs
**kwargs,
):
super().__init__(
objective,
Expand All @@ -40,7 +40,7 @@ def __init__(
rel_err_tolerance=rel_err_tolerance,
max_iterations=max_iterations,
step_size=step_size,
**kwargs
**kwargs,
)

def compute_delta(self, **kwargs) -> torch.Tensor:
Expand Down
53 changes: 39 additions & 14 deletions theseus/optimizer/nonlinear/levenberg_marquardt.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,37 +133,62 @@ def compute_delta(
damping_eps=damping_eps,
)

# Updates damping per batch element depending on whether the last step
# was successful in decreasing error or not.
# Based on https://people.duke.edu/~hpgavin/ce281/lm.pdf, Section 4.1
# We currently use method (1) from 4.1.1
def _update_state_impl(
def _complete_step(
self,
last_err: torch.Tensor,
new_err: torch.Tensor,
delta: torch.Tensor,
new_err: torch.Tensor,
previous_err: torch.Tensor,
adaptive_damping: bool = False,
down_damping_ratio: float = 9.0,
up_damping_ratio: float = 11.0,
damping_accept: float = 0.1,
**kwargs,
) -> None:
if not adaptive_damping:
return
) -> Optional[torch.Tensor]:
if adaptive_damping:
return self._check_accept(
delta,
new_err,
previous_err,
damping_accept,
down_damping_ratio,
up_damping_ratio,
)
else:
return None

# Checks if the step should be accepted (per batch element)
# Adjusts self._damping accordingly
# Returns a mask indicating which batch indices were accepted
#
# Based on https://people.duke.edu/~hpgavin/ce281/lm.pdf, Section 4.1
# We currently use method (1) from 4.1.1
@torch.no_grad()
def _check_accept(
self,
delta: torch.Tensor,
err: torch.Tensor,
previous_err: torch.Tensor,
damping_accept: float,
down_damping_ratio: float,
up_damping_ratio: float,
) -> torch.Tensor:
linearization = self.linear_solver.linearization
damping = (
self._damping.view(-1, 1)
if isinstance(self._damping, torch.Tensor)
else self._damping
)
# Deliberately using Atb before updating the variables, according to
# the LM reference above
den = (delta * (damping * delta + linearization.Atb.squeeze(2))).sum(dim=1)
rho = (last_err - new_err) / den
good_idx = rho > damping_accept
rho = (previous_err - err) / den
reject_indices = rho <= damping_accept
self._damping = torch.where(
good_idx,
self._damping / down_damping_ratio,
reject_indices,
self._damping * up_damping_ratio,
self._damping / down_damping_ratio,
)
self._damping = self._damping.clamp(
LevenbergMarquardt._MIN_DAMPING, LevenbergMarquardt._MAX_DAMPING
)
return reject_indices
117 changes: 89 additions & 28 deletions theseus/optimizer/nonlinear/nonlinear_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,32 @@ def resolve(key: Union[str, "BackwardMode"]) -> "BackwardMode":
]


# Base class for all nonlinear optimizers, providing the skeleton of the
# optimization loop. Subclasses need to implement the following method:
#
# - `compute_delta`: returns a descent direction given the current values
# of the objective's optimization vars.
#
# Optionally, they can also provide the following methods:
#
# - `reset`: resets any internal state needed by the optimizer.
# - `_complete_step`: called at the end of an optimization step, but before
#   optimization variables are updated. Returns batch indices that should not
#   be updated (e.g., if the step is to be rejected).
#
# The high level logic of a call to optimize is as follows:
#
# prev_err = objective.error_squared_norm()
# do optimization loop:
# 1. compute delta
# 2. step(delta, prev_err)
# 2.1. Store current optim var tensors in tmp_optim_vars containers
# 2.2. Retract all tmp_optim_vars given delta
# 2.3. Evaluate new error
# 2.4. reject_indices = self._complete_step(delta, new_err, prev_err)
# 2.5. Update objective's optim var containers with retracted values,
# ignoring indices given by `reject_indices`
# 3. Check convergence
class NonlinearOptimizer(Optimizer, abc.ABC):
def __init__(
self,
Expand All @@ -104,9 +130,11 @@ def __init__(
linearization_kwargs=linearization_kwargs,
**linear_solver_kwargs,
)
self.ordering = self.linear_solver.linearization.ordering
self.params = NonlinearOptimizerParams(
abs_err_tolerance, rel_err_tolerance, max_iterations, step_size
)
self._tmp_optim_vars = tuple(v.copy(new_name=v.name) for v in self.ordering)

def set_params(self, **kwargs):
self.params.update(kwargs)
Expand All @@ -128,7 +156,7 @@ def _maybe_init_best_solution(
if not do_init:
return None
solution_dict = {}
for var in self.linear_solver.linearization.ordering:
for var in self.ordering:
solution_dict[var.name] = var.tensor.detach().clone().cpu()
return solution_dict

Expand Down Expand Up @@ -204,7 +232,7 @@ def _update_info(
assert info.best_err is not None
good_indices = err < info.best_err
info.best_iter[good_indices] = current_iter
for var in self.linear_solver.linearization.ordering:
for var in self.ordering:
info.best_solution[var.name][good_indices] = (
var.tensor.detach().clone()[good_indices].cpu()
)
Expand Down Expand Up @@ -327,20 +355,20 @@ def _optimize_loop(
with torch.no_grad():
if steps_tensor is None:
steps_tensor = torch.ones_like(delta) * self.params.step_size
self.objective.retract_optim_vars(

# For now, the step size is combined with delta. If we add a more sophisticated
# line search, we will probably need to pass it separately, or compute it inside.
err = self._step(
delta * steps_tensor,
self.linear_solver.linearization.ordering,
ignore_mask=converged_indices,
force_update=force_update,
)
info.last_err,
converged_indices,
force_update,
**kwargs,
) # err is shape (batch_size,)

# check for convergence
with torch.no_grad():
err = self.objective.error_squared_norm() / 2
self._update_info(info, it_, err, converged_indices)
self.update_optimizer_state(
last_err=info.last_err, new_err=err, delta=delta, **kwargs
)
if verbose:
print(
f"Nonlinear optimizer. Iteration: {it_+1}. "
Expand Down Expand Up @@ -460,30 +488,63 @@ def _optimize_impl(
def compute_delta(self, **kwargs) -> torch.Tensor:
pass

# Adds references to the current optim variable tensors in the optimizer's
# _tmp_optim_vars containers. This allows us to compute t_next = V.tensor + delta
# for any optimization variable, without changing the permanent optim var objects
# in the objective.
def _update_tmp_optim_vars(self):
for v_tmp, v_order in zip(self._tmp_optim_vars, self.ordering):
v_tmp.update(v_order.tensor)

# Given descent directions and step sizes, updates the optimization
# variables.
# Batch indices indicated by `converged_indices` mask are ignored
# unless `force_update = True`.
# Returns the total error tensor after the update
def _step(
self,
delta: torch.Tensor,
previous_err: torch.Tensor,
converged_indices: torch.Tensor,
force_update: bool,
**kwargs,
) -> torch.Tensor:
# makes sure tmp containers are up to date with current variables
self._update_tmp_optim_vars()
# stores the result of the retract step in `self._tmp_optim_vars`
self.objective.retract_vars_sequence(
delta,
self._tmp_optim_vars,
ignore_mask=converged_indices,
force_update=force_update,
)
tensor_map = {v.name: v.tensor for v in self._tmp_optim_vars}
with torch.no_grad():
err = self.objective.error_squared_norm(tensor_map, also_update=False)

reject_indices = self._complete_step(delta, err, previous_err, **kwargs)
self.objective.update(tensor_map, batch_ignore_mask=reject_indices)

return err

# Resets any internal state needed by the optimizer for a new optimization
# problem. Optimizer loop will pass all optimizer kwargs to this method.
# Deliberately not abstract, since some optimizers might not need this
def reset(self, **kwargs) -> None:
pass

# Called at the end of every optimizer step to update any internal state
# of the optimizer
@torch.no_grad()
def update_optimizer_state(
# Called at the end of step() but before variables are updated to their new values.
# This method can be used to update any internal state of the optimizer and
# also obtain an optional tensor of shape (batch_size,), representing
# a mask of booleans indicating if the step is to be rejected for any
# batch elements.
#
# Deliberately not abstract, since some optimizers might not need this.
def _complete_step(
self,
last_err: torch.Tensor,
new_err: torch.Tensor,
delta: torch.Tensor,
**kwargs,
) -> None:
self._update_state_impl(last_err, new_err, delta, **kwargs)

# Deliberately not abstract, since some optimizers might not need this
def _update_state_impl(
self,
last_err: torch.Tensor,
new_err: torch.Tensor,
delta: torch.Tensor,
previous_err: torch.Tensor,
**kwargs,
) -> None:
pass
) -> Optional[torch.Tensor]:
return None
6 changes: 3 additions & 3 deletions theseus/theseus_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ def backward(ctx, *grad_outputs):
optim_tensors = saved_tensors[n + k + k :]
grad_outputs = grad_outputs[:-1]

bwd_objective = ctx.bwd_objective
bwd_optimizer = ctx.bwd_optimizer
bwd_objective: Objective = ctx.bwd_objective
bwd_optimizer: Optimizer = ctx.bwd_optimizer
epsilon = ctx.epsilon
input_keys = ctx.input_keys

Expand All @@ -246,7 +246,7 @@ def backward(ctx, *grad_outputs):
with torch.no_grad():
bwd_optimizer.linear_solver.linearization.linearize()
delta = bwd_optimizer.linear_solver.solve()
bwd_optimizer.objective.retract_optim_vars(
bwd_optimizer.objective.retract_vars_sequence(
delta, bwd_optimizer.linear_solver.linearization.ordering
)

Expand Down

0 comments on commit abffd1a

Please sign in to comment.