DLM gradients (#161)
* DLM gradients hacky example

* implement DLM using autograd.Function

* make soln a bit more accurate

* minor; removed unnecessary code

* lower case for dlm_epsilon

* backward test for DLM

* fix imports

* rename and make linter happy

* filter for tensors that require grad

* Construct the bwd objective only once

* minor

* remove print statements

* Fix DLM when using gpu; cost function shape; and handle case when no differentiable tensor

* Fix memory leak by removing dict input_data from input arguments

* preserve ordering

* Expand batch dim if possible

* undo

* use lower case

* reduce a bit of python overhead

* explicit one step
rtqichen authored Jun 7, 2022
1 parent 336059a commit 9e11ced
Showing 4 changed files with 237 additions and 11 deletions.
36 changes: 35 additions & 1 deletion examples/backward_modes.py
@@ -67,7 +67,7 @@ def quad_error_fn(optim_vars, aux_vars):
optimizer = th.GaussNewton(
objective,
max_iterations=15,
step_size=0.5,
step_size=1.0,
)

theseus_inputs = {
@@ -128,6 +128,22 @@ def quad_error_fn(optim_vars, aux_vars):
print(da_dx.numpy())


# We can also compute the direct loss minimization gradient.
updated_inputs, info = theseus_optim.forward(
theseus_inputs,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.DLM,
"dlm_epsilon": 1e-3,
},
)

da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[0].squeeze()
print("\n--- backward_mode=DLM")
print(da_dx.numpy())


# Next we numerically check the derivative
def fit_x(data_x_np):
theseus_inputs["x"] = (
@@ -199,6 +215,21 @@ def fit_x(data_x_np):
].squeeze()
times["bwd_trunc"].append(time.time() - start)

updated_inputs, info = theseus_optim.forward(
theseus_inputs,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.DLM,
"dlm_epsilon": 1e-3,
},
)
start = time.time()
da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[
0
].squeeze()
times["bwd_dlm"].append(time.time() - start)


print("\n=== Runtimes")
k = "fwd"
@@ -214,3 +245,6 @@ def fit_x(data_x_np):
print(
f"Backward (TRUNCATED, 5 steps) {np.mean(times[k]):.2e} s +/- {np.std(times[k]):.2e} s"
)

k = "bwd_dlm"
print(f"Backward (DLM) {np.mean(times[k]):.2e} s +/- {np.std(times[k]):.2e} s")
3 changes: 2 additions & 1 deletion theseus/optimizer/nonlinear/nonlinear_optimizer.py
@@ -55,6 +55,7 @@ class BackwardMode(Enum):
FULL = 0
IMPLICIT = 1
TRUNCATED = 2
DLM = 3


class NonlinearOptimizer(Optimizer, abc.ABC):
@@ -307,7 +308,7 @@ def _optimize_impl(
f"Error: {info.last_err.mean().item()}"
)

if backward_mode == BackwardMode.FULL:
if backward_mode in [BackwardMode.FULL, BackwardMode.DLM]:
info = self._optimize_loop(
start_iter=0,
num_iter=self.params.max_iterations,
16 changes: 15 additions & 1 deletion theseus/optimizer/nonlinear/tests/test_backwards.py
@@ -50,7 +50,7 @@ def quad_error_fn(optim_vars, aux_vars):
optimizer = th.GaussNewton(
objective,
max_iterations=15,
step_size=0.5,
step_size=1.0,
)

theseus_inputs = {
@@ -119,3 +119,17 @@ def fit_x(data_x_np):
updated_inputs["a"], data_x, retain_graph=True
)[0].squeeze()
assert torch.allclose(da_dx_numeric, da_dx_truncated, atol=1e-4)

updated_inputs, _ = theseus_optim.forward(
theseus_inputs,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.DLM,
"dlm_epsilon": 0.001,
},
)
da_dx_truncated = torch.autograd.grad(
updated_inputs["a"], data_x, retain_graph=True
)[0].squeeze()
assert torch.allclose(da_dx_numeric, da_dx_truncated, atol=1e-3)
193 changes: 185 additions & 8 deletions theseus/theseus_layer.py
@@ -7,9 +7,13 @@

import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

from theseus.core import Variable
from theseus.core.cost_function import AutoDiffCostFunction
from theseus.optimizer import Optimizer, OptimizerInfo
from theseus.optimizer.linear import LinearSolver
from theseus.optimizer.nonlinear import BackwardMode, GaussNewton


class TheseusLayer(nn.Module):
@@ -22,6 +26,9 @@ def __init__(
self.optimizer = optimizer
self._objectives_version = optimizer.objective.current_version

self._dlm_bwd_objective = None
self._dlm_bwd_optimizer = None

def forward(
self,
input_data: Optional[Dict[str, torch.Tensor]] = None,
@@ -32,15 +39,38 @@ def forward(
"The objective was modified after the layer's construction, which is "
"currently not supported."
)
self.objective.update(input_data)
optimizer_kwargs = optimizer_kwargs or {}
info = self.optimizer.optimize(**optimizer_kwargs)
values = dict(
[
(var_name, var.data)
for var_name, var in self.objective.optim_vars.items()
]
)
backward_mode = optimizer_kwargs.get("backward_mode", None)
dlm_epsilon = optimizer_kwargs.get(TheseusLayerDLMForward._dlm_epsilon, 1e-2)
if backward_mode == BackwardMode.DLM:

if self._dlm_bwd_objective is None:
_obj, _opt = _instantiate_dlm_bwd_objective(self.objective)
_obj.to(self.device)
self._dlm_bwd_objective = _obj
self._dlm_bwd_optimizer = _opt

# Tensors cannot be passed inside containers, else we run into memory leaks.
input_keys, input_vals = zip(*input_data.items())
differentiable_tensors = [t for t in input_vals if t.requires_grad]

*vars, info = TheseusLayerDLMForward.apply(
self.objective,
self.optimizer,
optimizer_kwargs,
self._dlm_bwd_objective,
self._dlm_bwd_optimizer,
dlm_epsilon,
len(input_keys),
*input_keys,
*input_vals,
*differentiable_tensors,
)
else:
vars, info = _forward(
self.objective, self.optimizer, optimizer_kwargs, input_data
)
values = dict(zip(self.objective.optim_vars.keys(), vars))
return values, info

def compute_samples(
@@ -93,3 +123,150 @@ def device(self) -> torch.device:
@property
def dtype(self) -> torch.dtype:
return self.objective.dtype


def _forward(objective, optimizer, optimizer_kwargs, input_data):
objective.update(input_data)
info = optimizer.optimize(**optimizer_kwargs)
vars = [var.data for var in objective.optim_vars.values()]
return vars, info


class TheseusLayerDLMForward(torch.autograd.Function):
"""
Functionally the same as the forward method in a TheseusLayer
but computes the direct loss minimization in the backward pass.
"""

_dlm_epsilon = "dlm_epsilon"
_grad_suffix = "_grad"

@staticmethod
def forward(
ctx,
objective,
optimizer,
optimizer_kwargs,
bwd_objective,
bwd_optimizer,
epsilon,
n,
*input_data,
):
input_keys = input_data[:n]
input_vals = input_data[n : 2 * n]
differentiable_tensors = input_data[2 * n :]
ctx.n = n
ctx.k = len(differentiable_tensors)

input_data = dict(zip(input_keys, input_vals))
ctx.input_keys = input_keys

optim_tensors, info = _forward(
objective, optimizer, optimizer_kwargs, input_data
)

# Skip computation if there are no differentiable inputs.
if ctx.k > 0:
ctx.bwd_objective = bwd_objective
ctx.bwd_optimizer = bwd_optimizer
ctx.epsilon = epsilon

# Precompute and cache this.
with torch.enable_grad():
grad_sol = torch.autograd.grad(
objective.error_squared_norm().sum(),
differentiable_tensors,
allow_unused=True,
)
ctx.save_for_backward(
*input_vals, *grad_sol, *differentiable_tensors, *optim_tensors
)
return (*optim_tensors, info)

@staticmethod
@once_differentiable
def backward(ctx, *grad_outputs):
n, k = ctx.n, ctx.k
saved_tensors = ctx.saved_tensors
input_vals = saved_tensors[:n]
grad_sol = saved_tensors[n : n + k]
differentiable_tensors = saved_tensors[n + k : n + k + k]
optim_tensors = saved_tensors[n + k + k :]
grad_outputs = grad_outputs[:-1]

bwd_objective = ctx.bwd_objective
bwd_optimizer = ctx.bwd_optimizer
epsilon = ctx.epsilon
input_keys = ctx.input_keys

# Update the optim vars to their solutions.
bwd_data = dict(zip(input_keys, input_vals))
for k, v in zip(bwd_objective.optim_vars.keys(), optim_tensors):
bwd_data[k] = v.detach()

# Add in gradient values.
grad_data = {
TheseusLayerDLMForward._dlm_epsilon: torch.tensor(epsilon)
.to(grad_outputs[0])
.reshape(1, 1)
}
for i, name in enumerate(bwd_objective.optim_vars.keys()):
grad_data[name + TheseusLayerDLMForward._grad_suffix] = grad_outputs[i]
bwd_data.update(grad_data)

# Solve backward objective.
bwd_objective.update(bwd_data)
with torch.no_grad():
bwd_optimizer.linear_solver.linearization.linearize()
delta = bwd_optimizer.linear_solver.solve()
bwd_optimizer.retract_and_update_variables(
delta, None, 1.0, force_update=True
)

# Compute gradients.
with torch.enable_grad():
grad_perturbed = torch.autograd.grad(
bwd_objective.error_squared_norm().sum(),
differentiable_tensors,
allow_unused=True,
)

nones = [None] * (ctx.n * 2)
grads = [
(gs - gp) / epsilon if gs is not None else None
for gs, gp in zip(grad_sol, grad_perturbed)
]
return (None, None, None, None, None, None, None, *nones, *grads)


def _dlm_perturbation(optim_vars, aux_vars):
v = optim_vars[0]
g = aux_vars[0]
epsilon = aux_vars[1]
return epsilon.data * v.data - 0.5 * g.data


def _instantiate_dlm_bwd_objective(objective):
bwd_objective = objective.copy()
epsilon_var = Variable(torch.ones(1, 1), name=TheseusLayerDLMForward._dlm_epsilon)
for name, var in bwd_objective.optim_vars.items():
grad_var = Variable(
torch.zeros_like(var.data), name=name + TheseusLayerDLMForward._grad_suffix
)
bwd_objective.add(
AutoDiffCostFunction(
[var],
_dlm_perturbation,
var.shape[1],
aux_vars=[grad_var, epsilon_var],
name="dlm_perturbation_" + name,
)
)

bwd_optimizer = GaussNewton(
bwd_objective,
max_iterations=1,
step_size=1.0,
)
return bwd_objective, bwd_optimizer
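
A self-contained toy check of this estimator in plain PyTorch (not part of this commit and not using the theseus API; the inner objective, the value of g, and all names below are made up for illustration). The inner problem is argmin_theta (theta - x**2)**2 with solution theta* = x**2; taking the downstream loss L = theta* gives dL/dx = 2x and g = dL/dtheta = 1, and the perturbed objective adds the same residual form used by _dlm_perturbation, epsilon * theta - 0.5 * g:

import torch

# Inner objective E(x, theta) = (theta - x**2)**2, so theta*(x) = x**2.
# Downstream loss L = theta*, hence analytically dL/dx = 2x and g = dL/dtheta = 1.
x = torch.tensor(1.5)
epsilon = 1e-3
g = 1.0

def grad_x_of_E(x, theta):
    # Partial derivative of (theta - x**2)**2 with respect to x.
    return -4.0 * x * (theta - x**2)

theta_star = x**2
# Closed-form minimizer of the perturbed objective
# E(x, theta) + (epsilon * theta - 0.5 * g)**2.
theta_eps = (2 * x**2 + epsilon * g) / (2 + 2 * epsilon**2)

dlm_grad = (grad_x_of_E(x, theta_star) - grad_x_of_E(x, theta_eps)) / epsilon
print(dlm_grad.item(), (2 * x).item())  # both close to 3.0

Squaring the perturbation residual epsilon * v - 0.5 * g expands to a -epsilon * g * v tilt of the objective, plus an O(epsilon**2) quadratic term and a constant, which is what makes the finite difference above recover dL/dx.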
