
Commit

Moved Vectorize(objective) to the Optimizer class. (#218)
* Moved Vectorize(objective) to the Optimizer class.

* Vectorized tests for nonlinear optimizers.

* Small fixes.

* Fixed flaky unit test.

* Added option to vectorize objective from TheseusLayer.
luisenp authored Jun 16, 2022
1 parent 7329a78 commit ce6e127
Showing 14 changed files with 119 additions and 31 deletions.
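As a quick illustration of the resulting API (a minimal sketch with illustrative variable and cost-function names, not part of the diff): optimizers now vectorize their objective by default, and `TheseusLayer` only vectorizes an objective that is not vectorized yet.

```python
import torch
import theseus as th

# Toy objective (illustrative only): fit a 1-D variable x to a target y.
x = th.Vector(1, name="x")
y = th.Variable(torch.ones(1, 1), name="y")

def err_fn(optim_vars, aux_vars):
    return optim_vars[0].data - aux_vars[0].data

objective = th.Objective()
objective.add(th.AutoDiffCostFunction([x], err_fn, 1, aux_vars=[y], name="err"))

# Optimizers now call Vectorize(objective) themselves (vectorize=True by default).
optimizer = th.GaussNewton(objective, max_iterations=5)
assert objective.vectorized

# The layer no longer re-vectorizes an already vectorized objective.
layer = th.TheseusLayer(optimizer)

# To defer vectorization to the layer instead:
#   optimizer = th.GaussNewton(objective, vectorize=False)
#   layer = th.TheseusLayer(optimizer, vectorize=True)
```

Existing code that passes `vectorize=True` to `TheseusLayer` keeps working; the layer now simply skips vectorization when the objective already reports `vectorized`.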
42 changes: 41 additions & 1 deletion theseus/core/objective.py
@@ -5,7 +5,7 @@

import warnings
from collections import OrderedDict
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union

import torch

@@ -74,6 +74,8 @@ def __init__(self, dtype: Optional[torch.dtype] = None):
# If vectorization is on, this gets replaced by a vectorized version
self._retract_method = Objective._retract_base

self._vectorized = False

def _add_function_variables(
self,
function: TheseusFunction,
@@ -534,3 +536,41 @@ def retract_optim_vars(
self._retract_method(
delta, ordering, ignore_mask=ignore_mask, force_update=force_update
)

def _enable_vectorization(
self,
cost_fns_iter: Iterable[CostFunction],
vectorization_run_fn: Callable,
vectorized_to: Callable,
vectorized_retract_fn: Callable,
enabler: Any,
):
# Hacky way to make Vectorize a "friend" class
assert (
enabler.__module__ == "theseus.core.vectorizer"
and enabler.__class__.__name__ == "Vectorize"
)
self._cost_functions_iterable = cost_fns_iter
self._vectorization_run = vectorization_run_fn
self._vectorization_to = vectorized_to
self._retract_method = vectorized_retract_fn
self._vectorized = True

# Making public, since this should be a safe operation
def disable_vectorization(self):
self._cost_functions_iterable = None
self._vectorization_run = None
self._vectorization_to = None
self._retract_method = Objective._retract_base
self._vectorized = False

@property
def vectorized(self):
assert (
(not self._vectorized)
== (self._cost_functions_iterable is None)
== (self._vectorization_run is None)
== (self._vectorization_to is None)
== (self._retract_method is Objective._retract_base)
)
return self._vectorized
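
For reference, a short sketch (not part of the diff; names are illustrative) of how these hooks are meant to be exercised: only `Vectorize` can flip the objective into vectorized mode, while `disable_vectorization` is public.

```python
import torch
import theseus as th
from theseus.core import Vectorize

# Tiny illustrative objective.
x = th.Vector(1, name="x")
y = th.Variable(torch.ones(1, 1), name="y")
objective = th.Objective()
objective.add(
    th.AutoDiffCostFunction([x], lambda o, a: o[0].data - a[0].data, 1, aux_vars=[y])
)

assert not objective.vectorized       # fresh objective uses the base retract method

Vectorize(objective)                  # passes itself as `enabler`, so the guard passes
assert objective.vectorized

objective.disable_vectorization()     # public API; restores Objective._retract_base
assert not objective.vectorized

# Calling the private hook with any other enabler trips the "friend class" assert:
# objective._enable_vectorization(iter([]), None, None, None, enabler=object())
```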
4 changes: 2 additions & 2 deletions theseus/core/tests/test_robust_cost.py
@@ -9,8 +9,8 @@


def _new_robust_cf(batch_size, loss_cls, generator) -> th.RobustCostFunction:
v1 = th.rand_se3(batch_size)
v2 = th.rand_se3(batch_size)
v1 = th.rand_se3(batch_size, generator=generator)
v2 = th.rand_se3(batch_size, generator=generator)
w = th.ScaleCostWeight(torch.randn(1, generator=generator))
cf = th.Local(v1, w, v2)
ll_radius = th.Variable(data=torch.randn(1, 1, generator=generator))
15 changes: 9 additions & 6 deletions theseus/core/vectorizer.py
@@ -114,14 +114,17 @@ def __init__(self, objective: Objective):
# Dict[_CostFunctionSchema, List[str]]
self._var_names = self._get_var_names()

# `vectorize()` will compute an error vector for each schema, then populate
# the wrappers with their appropriate weighted error slice.
# `self._vectorize()` will compute an error vector for each schema,
# then populate the wrappers with their appropriate weighted error slice.
# Replacing `obj._cost_functions_iterable` makes it possible to recover these
# when iterating the Objective.
objective._cost_functions_iterable = self._cost_fn_wrappers
objective._vectorization_run = self._vectorize
objective._vectorization_to = self._to
objective._retract_method = self._vectorized_retract_optim_vars
objective._enable_vectorization(
self._cost_fn_wrappers,
self._vectorize,
self._to,
self._vectorized_retract_optim_vars,
self,
)

self._objective = objective

@@ -80,7 +80,7 @@ def ee_pose_err_fn(optim_vars, aux_vars):
max_iterations=15,
step_size=0.5,
)
theseus_optim = th.TheseusLayer(optimizer, vectorize=True)
theseus_optim = th.TheseusLayer(optimizer)

# Optimize
theseus_inputs = {
3 changes: 2 additions & 1 deletion theseus/optimizer/linear/linear_optimizer.py
@@ -28,12 +28,13 @@ def __init__(
objective: Objective,
linear_solver_cls: Type[LinearSolver],
*args,
vectorize: bool = True,
linearization_cls: Optional[Type[Linearization]] = None,
linearization_kwargs: Optional[Dict[str, Any]] = None,
linear_solver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
super().__init__(objective)
super().__init__(objective, vectorize=vectorize)
linearization_kwargs = linearization_kwargs or {}
linear_solver_kwargs = linear_solver_kwargs or {}
self.linear_solver = linear_solver_cls(
2 changes: 2 additions & 0 deletions theseus/optimizer/nonlinear/gauss_newton.py
@@ -19,6 +19,7 @@ def __init__(
self,
objective: Objective,
linear_solver_cls: Optional[Type[LinearSolver]] = None,
vectorize: bool = True,
linearization_cls: Optional[Type[Linearization]] = None,
linearization_kwargs: Optional[Dict[str, Any]] = None,
linear_solver_kwargs: Optional[Dict[str, Any]] = None,
@@ -30,6 +31,7 @@ def __init__(
super().__init__(
objective,
linear_solver_cls=linear_solver_cls,
vectorize=vectorize,
linearization_cls=linearization_cls,
linearization_kwargs=linearization_kwargs,
linear_solver_kwargs=linear_solver_kwargs,
2 changes: 2 additions & 0 deletions theseus/optimizer/nonlinear/levenberg_marquardt.py
@@ -31,6 +31,7 @@ def __init__(
self,
objective: Objective,
linear_solver_cls: Optional[Type[LinearSolver]] = None,
vectorize: bool = True,
linearization_cls: Optional[Type[Linearization]] = None,
linearization_kwargs: Optional[Dict[str, Any]] = None,
linear_solver_kwargs: Optional[Dict[str, Any]] = None,
@@ -42,6 +43,7 @@ def __init__(
super().__init__(
objective,
linear_solver_cls=linear_solver_cls,
vectorize=vectorize,
linearization_cls=linearization_cls,
linearization_kwargs=linearization_kwargs,
linear_solver_kwargs=linear_solver_kwargs,
2 changes: 2 additions & 0 deletions theseus/optimizer/nonlinear/nonlinear_least_squares.py
@@ -22,6 +22,7 @@ def __init__(
objective: Objective,
*args,
linear_solver_cls: Optional[Type[LinearSolver]] = None,
vectorize: bool = True,
linearization_cls: Optional[Type[Linearization]] = None,
linearization_kwargs: Optional[Dict[str, Any]] = None,
linear_solver_kwargs: Optional[Dict[str, Any]] = None,
@@ -35,6 +36,7 @@ def __init__(
super().__init__(
objective,
linear_solver_cls=linear_solver_cls,
vectorize=vectorize,
linearization_cls=linearization_cls,
linearization_kwargs=linearization_kwargs,
linear_solver_kwargs=linear_solver_kwargs,
3 changes: 2 additions & 1 deletion theseus/optimizer/nonlinear/nonlinear_optimizer.py
@@ -64,6 +64,7 @@ def __init__(
objective: Objective,
linear_solver_cls: Type[LinearSolver],
*args,
vectorize: bool = True,
linearization_cls: Optional[Type[Linearization]] = None,
linearization_kwargs: Optional[Dict[str, Any]] = None,
linear_solver_kwargs: Optional[Dict[str, Any]] = None,
@@ -73,7 +74,7 @@
step_size: float = 1.0,
**kwargs,
):
super().__init__(objective)
super().__init__(objective, vectorize=vectorize)
linear_solver_kwargs = linear_solver_kwargs or {}
self.linear_solver = linear_solver_cls(
objective,
59 changes: 46 additions & 13 deletions theseus/optimizer/nonlinear/tests/common.py
@@ -19,46 +19,79 @@ def __init__(
point=None,
multivar=False,
noise_mag=0,
target=None,
):
super().__init__(cost_weight, name=name)
len_vars = len(optim_vars) if multivar else optim_vars[0].dof()
assert true_coeffs.ndim == 1 and true_coeffs.numel() == len_vars
if point.ndim == 1:
point = point.unsqueeze(0)
assert point.ndim == 2 and point.shape[1] == len_vars - 1
self._optim_vars = optim_vars

if isinstance(true_coeffs, torch.Tensor):
assert true_coeffs.ndim == 1 and true_coeffs.numel() == len_vars
self.true_coeffs = th.Variable(data=true_coeffs.unsqueeze(0))
else:
self.true_coeffs = true_coeffs

batch_size = point.shape[0]
if isinstance(point, torch.Tensor):
if point.ndim == 1:
point = point.unsqueeze(0)
assert point.ndim == 2 and point.shape[1] == len_vars - 1
self.point = th.Variable(
data=torch.cat([point, torch.ones(batch_size, 1)], dim=1)
)
else:
self.point = point
for i, var in enumerate(optim_vars):
attr_name = f"optim_var_{i}"
setattr(self, attr_name, var)
self.register_optim_var(attr_name)
self.register_aux_var("true_coeffs")
self.register_aux_var("point")

if target is None:
target_data = (self.point.data * self.true_coeffs.data).sum(
1, keepdim=True
) ** 2
if noise_mag:
target_data += noise_mag * torch.randn(size=target_data.shape)
self.target = th.Variable(data=target_data)
else:
self.target = target
self.register_aux_var("target")

self.point = torch.cat([point, torch.ones(batch_size, 1)], dim=1)
self.target = (self.point * true_coeffs.unsqueeze(0)).sum(1, keepdim=True) ** 2
if noise_mag:
self.target += noise_mag * torch.randn(size=self.target.shape)
self.noise_mag = noise_mag
self.multivar = multivar

def _eval_coeffs(self):
if self.multivar:
coeffs = torch.cat([v.data for v in self.optim_vars], axis=1)
else:
coeffs = self.optim_var_0.data
return (self.point * coeffs).sum(1, keepdim=True)
return (self.point.data * coeffs).sum(1, keepdim=True)

def error(self):
# h(B * x) - h(Btrue * x)
return self._eval_coeffs() ** 2 - self.target
return self._eval_coeffs() ** 2 - self.target.data

def jacobians(self):
dhdz = 2 * self._eval_coeffs()
grad = self.point * dhdz
grad = self.point.data * dhdz
return [grad.unsqueeze(1)], self.error()

def dim(self):
return 1

def _copy_impl(self, new_name=None):
raise NotImplementedError
return ResidualCostFunction(
[v.copy() for v in self._optim_vars],
self.weight.copy(),
name=new_name,
true_coeffs=self.true_coeffs.copy(),
point=self.point.copy(),
multivar=self.multivar,
noise_mag=self.noise_mag,
target=self.target.copy(),
)


def _check_info(info, batch_size, max_iterations, initial_error, objective):
@@ -207,7 +240,7 @@ def _copy_impl(self):
values = {"dummy": torch.zeros(1, 1)}
objective.update(values)

optimizer = nonlinear_optim_cls(objective)
optimizer = nonlinear_optim_cls(objective, vectorize=False)
assert isinstance(optimizer.linear_solver, th.CholeskyDenseSolver)
optimizer.set_params(max_iterations=30)
optimizer.linear_solver.linearization = BadLinearization(objective)
2 changes: 1 addition & 1 deletion theseus/optimizer/nonlinear/tests/test_backwards.py
@@ -59,7 +59,7 @@ def quad_error_fn(optim_vars, aux_vars):
"x": data_x,
"y": data_y,
}
theseus_optim = th.TheseusLayer(optimizer, vectorize=True)
theseus_optim = th.TheseusLayer(optimizer)


def test_backwards():
6 changes: 4 additions & 2 deletions theseus/optimizer/optimizer.py
@@ -10,7 +10,7 @@
import numpy as np
import torch

from theseus.core import Objective
from theseus.core import Objective, Vectorize


# All info information is batched
Expand All @@ -22,8 +22,10 @@ class OptimizerInfo:


class Optimizer(abc.ABC):
def __init__(self, objective: Objective, *args, **kwargs):
def __init__(self, objective: Objective, *args, vectorize: bool = True, **kwargs):
self.objective = objective
if vectorize:
Vectorize(self.objective)
self._objectives_version = objective.current_version

@abc.abstractmethod
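
Custom optimizers pick up the new flag through `super().__init__`. A minimal, hypothetical subclass (`MyOptimizer` is illustrative, not part of Theseus) might look like this:

```python
from theseus.core import Objective
from theseus.optimizer.optimizer import Optimizer, OptimizerInfo


class MyOptimizer(Optimizer):
    # Hypothetical subclass: vectorization is now handled by the base class,
    # so passing vectorize=True wraps the objective with Vectorize on construction.
    def __init__(self, objective: Objective, *args, vectorize: bool = True, **kwargs):
        super().__init__(objective, *args, vectorize=vectorize, **kwargs)

    def _optimize_impl(self, **kwargs) -> OptimizerInfo:
        raise NotImplementedError  # sketch only
```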
3 changes: 3 additions & 0 deletions theseus/tests/test_theseus_layer.py
@@ -131,11 +131,14 @@ def error_fn(optim_vars, aux_vars):

optimizer = nonlinear_optimizer_cls(
objective,
vectorize=False,
linear_solver_cls=linear_solver_cls,
max_iterations=max_iterations,
)
assert isinstance(optimizer.linear_solver, linear_solver_cls)
assert not objective.vectorized
theseus_layer = th.TheseusLayer(optimizer, vectorize=True)
assert objective.vectorized
return theseus_layer


5 changes: 2 additions & 3 deletions theseus/theseus_layer.py
@@ -19,11 +19,10 @@ class TheseusLayer(nn.Module):
def __init__(self, optimizer: Optimizer, vectorize: bool = True):
super().__init__()
self.objective = optimizer.objective
if vectorize and not self.objective.vectorized:
Vectorize(self.objective)
self.optimizer = optimizer
self._objectives_version = optimizer.objective.current_version
if vectorize:
Vectorize(self.objective)

self._dlm_bwd_objective = None
self._dlm_bwd_optimizer = None

Expand Down
