New NonlinearOptimizer hierarchy (#440)
* Moved NLLS-specific methods from NonlinearOptimizer class to NLLS subclass.

* Moved linear solver usage from the NonlinearOptimizer class to the NLLS subclass.

* Moved BackwardMode, _merge_infos, and _split_backward_iters back to the base class.

* Moved error metric to Objective (#455)

* Deprecated Objective.error_squared_norm() in favor of the configurable error_metric() (see the usage sketch below).

* Made the error metric configurable and removed the _error_metric method from the nonlinear optimizer class.
luisenp authored Feb 6, 2023
1 parent 1dd5046 commit 16c8740
Showing 19 changed files with 474 additions and 430 deletions.
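
The call-site migration is mechanical: error_metric() defaults to half the squared norm of the error vector, so the explicit division by 2 that used to accompany error_squared_norm() disappears. A minimal, self-contained sketch of the new call (the toy variable, residual, and objective are hypothetical; only the error_metric() / error_squared_norm() behavior comes from this commit):

```python
import torch
import theseus as th

# Hypothetical toy objective: one residual equal to the value of a 3-dof vector.
x = th.Vector(tensor=torch.ones(2, 3), name="x")
cost = th.AutoDiffCostFunction(
    [x], lambda optim_vars, aux_vars: optim_vars[0].tensor, 3, name="residual"
)
objective = th.Objective()
objective.add(cost)
objective.update()

with torch.no_grad():
    # Before this commit: batch_error = objective.error_squared_norm().mean() / 2
    # After: error_metric() already returns half the squared norm per batch element.
    batch_error = objective.error_metric().mean()
print(f"mean objective value: {batch_error.item():.2f}")  # 1.50 for this toy setup
```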
6 changes: 3 additions & 3 deletions examples/motion_planning_2d.py
@@ -142,7 +142,7 @@ def run_learning_loop(cfg):
motion_planner.objective.update(planner_inputs)
initial_trajectory = motion_planner.get_trajectory()
with torch.no_grad():
batch_error = motion_planner.objective.error_squared_norm().mean() / 2
batch_error = motion_planner.objective.error_metric().mean()
print(f"Planner MSE optim first: {batch_error.item() : 10.2f}")

_, info = motion_planner.layer.forward(
@@ -163,14 +163,14 @@
)

with torch.no_grad():
batch_error = motion_planner.objective.error_squared_norm().mean() / 2
batch_error = motion_planner.objective.error_metric().mean()
print(f"Planner MSE optim final: {batch_error.item() : 10.2f}")

if cfg.do_learning:
gp_error, collision_error = motion_planner.get_total_squared_errors()
loss = 0
if cfg.use_mean_objective_as_loss:
loss = motion_planner.objective.error_squared_norm().mean()
loss = motion_planner.objective.error_metric().mean()
loss /= motion_planner.objective.dim()
loss *= cfg.obj_loss_weight
epoch_mean_objective_loss += loss.item()
4 changes: 1 addition & 3 deletions examples/pose_graph/pose_graph_benchmark.py
@@ -87,9 +87,7 @@ def main(cfg):
log.info(f"Forward pass used {forward_mem} MBs.")

results = {}
results["objective"] = (
objective.error_squared_norm().detach().cpu().numpy().sum() / 2
)
results["objective"] = objective.error_metric().detach().cpu().numpy().sum()
results["R"] = torch.cat(
[pose.tensor[:, :, :d].detach().cpu() for pose in verts]
).numpy()
2 changes: 1 addition & 1 deletion examples/state_estimation_2d.py
@@ -288,7 +288,7 @@ def cost_weights_model():
objective.update(theseus_inputs)
with torch.no_grad():
if epoch % 10 == 0:
print("Initial error:", objective.error_squared_norm().mean().item())
print("Initial error:", objective.error_metric().mean().item())

for i in range(inner_loop_iters):
theseus_inputs, _ = state_estimator.forward(
10 changes: 5 additions & 5 deletions tests/core/test_objective.py
@@ -244,9 +244,9 @@ def _check_error_for_data(v1_data_, v2_data_, error_, error_type):
expected_error = torch.cat([v1_data_, v2_data_], dim=1) * w

if error_type == "error":
assert error_.allclose(expected_error)
torch.testing.assert_close(error_, expected_error)
else:
assert error_.allclose(expected_error.norm(dim=1) ** 2)
torch.testing.assert_close(error_, 0.5 * (expected_error.norm(dim=1) ** 2))

def _check_variables(objective, input_tensors, v1_data, v2_data, also_update):
if also_update:
@@ -289,7 +289,7 @@ def _check_error_and_variables(
objective.update({"v1": v1_data, "v2": v2_data})
error = objective.error()
_check_error_for_data(v1_data, v2_data, error, "error")
error_norm_2 = objective.error_squared_norm()
error_norm_2 = objective.error_metric()

assert error.shape == (batch_size, 2 * dof)
_check_error_for_data(v1_data, v2_data, error_norm_2, "error_norm_2")
@@ -316,7 +316,7 @@ def _check_error_and_variables(

input_tensors = {"v1": v1_data_new, "v2": v2_data_new}

error_norm_2 = objective.error_squared_norm(
error_norm_2 = objective.error_metric(
input_tensors=input_tensors, also_update=False
)

@@ -352,7 +352,7 @@ def _check_error_and_variables(

input_tensors = {"v1": v1_data_new, "v2": v2_data_new}

error_norm_2 = objective.error_squared_norm(
error_norm_2 = objective.error_metric(
input_tensors=input_tensors, also_update=True
)

2 changes: 1 addition & 1 deletion tests/core/test_vectorizer.py
@@ -365,7 +365,7 @@ def _solve_fn_for_masked_jacobians(
for _ in range(5): # do a few steps
optim.zero_grad()
layer.forward(input_tensors)
loss = obj.error_squared_norm().sum()
loss = obj.error_metric().sum()
loss.backward()
optim.step()

6 changes: 3 additions & 3 deletions tests/optimizer/nonlinear/common.py
@@ -98,7 +98,7 @@ def _check_info(info, batch_size, max_iterations, initial_error, objective):
assert info.err_history.shape == (batch_size, max_iterations + 1)
assert info.err_history[:, 0].allclose(initial_error)
assert info.err_history.argmin(dim=1).allclose(info.best_iter + 1)
last_error = objective.error_squared_norm() / 2
last_error = objective.error_metric()
last_convergence_idx = info.converged_iter.max().item()
assert info.err_history[:, last_convergence_idx].allclose(last_error)

@@ -141,7 +141,7 @@ def _check_nonlinear_least_squares_fit(
# Initial value is B = [0, 1, ..., nvars - 1]
values = {"coefficients": torch.arange(nvars).repeat(batch_size, 1).float()}
objective.update(values)
initial_error = objective.error_squared_norm() / 2
initial_error = objective.error_metric()
max_iterations = 20
optimizer = nonlinear_optim_cls(objective)
assert isinstance(optimizer.linear_solver, th.CholeskyDenseSolver)
@@ -195,7 +195,7 @@ def _check_nonlinear_least_squares_fit_multivar(
# Initial value is B = [0, 1, ..., nvars - 1]
values = dict((f"coeff{i}", i * torch.ones(batch_size, 1)) for i in range(nvars))
objective.update(values)
initial_error = objective.error_squared_norm() / 2
initial_error = objective.error_metric()

max_iterations = 20
optimizer = nonlinear_optim_cls(objective)
2 changes: 1 addition & 1 deletion tests/optimizer/nonlinear/test_trust_region.py
@@ -49,7 +49,7 @@ def _rand_w():
opt = th.Dogleg(o, linear_solver_cls=linear_solver_cls)
o._resolve_batch_size()
opt.linear_solver.linearization.linearize()
previous_err = opt._error_metric()
previous_err = opt.objective.error_metric()

# Check rho = 1. Predicted error by TrustRegion method should
# match actual error after step for a linear problem
8 changes: 5 additions & 3 deletions tests/test_dlm_perturbation.py
@@ -16,7 +16,7 @@ def _original_dlm_perturbation(optim_vars, aux_vars):
v = optim_vars[0]
g = aux_vars[0]
epsilon = aux_vars[1]
return epsilon.tensor * v.tensor - 0.5 * g.tensor
return (epsilon.tensor * v.tensor - 0.5 * g.tensor) * np.sqrt(2)


def test_dlm_perturbation_jacobian():
@@ -68,8 +68,10 @@ def new_error_fn(vars):
aux_vars=[grad, epsilon],
)
original_jac, original_err = original_cf.jacobians()
assert error.allclose(original_err)
assert jacobians[0].allclose(original_jac[0], atol=1e-5)
torch.testing.assert_close(error, original_err)
torch.testing.assert_close(
jacobians[0], original_jac[0], atol=1e-5, rtol=1e-5
)


def test_backward_pass_se3_runs():
53 changes: 47 additions & 6 deletions theseus/core/objective.py
@@ -2,11 +2,20 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import warnings
from collections import OrderedDict
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Protocol,
Sequence,
Union,
)

import torch

@@ -19,11 +19,28 @@
from .variable import Variable


class ErrorMetric(Protocol):
def __call__(self, error_vector: torch.Tensor) -> torch.Tensor:
pass


def error_squared_norm_fn(error_vector: torch.Tensor) -> torch.Tensor:
return (error_vector**2).sum(dim=1) / 2


# If dtype is None, uses torch.get_default_dtype()
class Objective:
"""An objective function to optimize."""

def __init__(self, dtype: Optional[torch.dtype] = None):
def __init__(
self,
dtype: Optional[torch.dtype] = None,
error_metric_fn: Optional[ErrorMetric] = None,
):
# maps variable names to the variable objects
self.optim_vars: OrderedDict[str, Manifold] = OrderedDict()

@@ -100,6 +122,12 @@ def __init__(self, dtype: Optional[torch.dtype] = None):

self._vectorized = False

# Computes an aggregation function for the error vector derived from costs
# By default, this computes the squared norm of the error vector, divided by 2
self._error_metric_fn = (
error_metric_fn if error_metric_fn is not None else error_squared_norm_fn
)

def _add_function_variables(
self,
function: TheseusFunction,
@@ -425,14 +453,27 @@ def error(
self.update(old_tensors, _update_vectorization=False)
return error_vector

def error_metric(
self,
input_tensors: Optional[Dict[str, torch.Tensor]] = None,
also_update: bool = False,
) -> torch.Tensor:
return self._error_metric_fn(
self.error(input_tensors=input_tensors, also_update=also_update)
)

def error_squared_norm(
self,
input_tensors: Optional[Dict[str, torch.Tensor]] = None,
also_update: bool = False,
) -> torch.Tensor:
return (
self.error(input_tensors=input_tensors, also_update=also_update) ** 2
).sum(dim=1)
warnings.warn(
"Objective.error_squared_norm() is deprecated "
"and will be removed in future versions. "
"Please use Objective.error_metric() instead.",
DeprecationWarning,
)
return self.error_metric(input_tensors=input_tensors, also_update=also_update)

def copy(self) -> "Objective":
new_objective = Objective(dtype=self.dtype)
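
Because the aggregation is now a Protocol, an Objective can also be constructed with a custom metric. A hedged sketch, assuming the post-commit constructor signature; the pseudo-Huber function and the toy cost below are illustrative and not part of the library:

```python
import torch
import theseus as th

# Hypothetical robust aggregation conforming to the ErrorMetric protocol:
# maps the (batch_size, err_dim) error vector to one scalar per batch element.
def pseudo_huber_metric(error_vector: torch.Tensor) -> torch.Tensor:
    delta = 1.0
    return (delta**2 * (torch.sqrt(1.0 + (error_vector / delta) ** 2) - 1.0)).sum(dim=1)

x = th.Vector(tensor=2 * torch.ones(4, 2), name="x")
cost = th.AutoDiffCostFunction(
    [x], lambda optim_vars, aux_vars: optim_vars[0].tensor, 2, name="toy_cost"
)

objective = th.Objective(error_metric_fn=pseudo_huber_metric)
objective.add(cost)
objective.update()

print(objective.error_metric())        # pseudo-Huber aggregation, shape (4,)
print(objective.error_squared_norm())  # deprecated alias; forwards to error_metric()
                                        # and emits a DeprecationWarning
```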
2 changes: 1 addition & 1 deletion theseus/optimizer/nonlinear/levenberg_marquardt.py
@@ -149,7 +149,7 @@ def _complete_step(
**kwargs,
) -> Optional[torch.Tensor]:
# "err" tensors passed as input refer to the squared norm of the
# error vector, as returned by self._error_metric()
# error vector, as returned by self.objective.error_metric()
if adaptive_damping:
return self._check_accept(
delta,