DLM gradients #161

Merged: 23 commits, Jun 7, 2022
Changes shown from 5 commits.
36 changes: 35 additions & 1 deletion examples/backward_modes.py
@@ -67,7 +67,7 @@ def quad_error_fn(optim_vars, aux_vars):
optimizer = th.GaussNewton(
objective,
max_iterations=15,
- step_size=0.5,
+ step_size=1.0,
)

theseus_inputs = {
@@ -128,6 +128,22 @@ def quad_error_fn(optim_vars, aux_vars):
print(da_dx.numpy())


# We can also compute the direct loss minimization gradient.
updated_inputs, info = theseus_optim.forward(
theseus_inputs,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.DLM,
"DLM_epsilon": 1e-3,
},
)

da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[0].squeeze()
print("\n--- backward_mode=DLM")
print(da_dx.numpy())


# Next we numerically check the derivative
def fit_x(data_x_np):
theseus_inputs["x"] = (
@@ -199,6 +215,21 @@ def fit_x(data_x_np):
].squeeze()
times["bwd_trunc"].append(time.time() - start)

updated_inputs, info = theseus_optim.forward(
theseus_inputs,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.DLM,
"DLM_epsilon": 1e-2,
},
)
start = time.time()
da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[
0
].squeeze()
times["bwd_dlm"].append(time.time() - start)


print("\n=== Runtimes")
k = "fwd"
@@ -214,3 +245,6 @@ def fit_x(data_x_np):
print(
f"Backward (TRUNCATED, 5 steps) {np.mean(times[k]):.2e} s +/- {np.std(times[k]):.2e} s"
)

k = "bwd_dlm"
print(f"Backward (DLM) {np.mean(times[k]):.2e} s +/- {np.std(times[k]):.2e} s")
3 changes: 2 additions & 1 deletion theseus/optimizer/nonlinear/nonlinear_optimizer.py
@@ -55,6 +55,7 @@ class BackwardMode(Enum):
FULL = 0
IMPLICIT = 1
TRUNCATED = 2
DLM = 3


class NonlinearOptimizer(Optimizer, abc.ABC):
@@ -307,7 +308,7 @@ def _optimize_impl(
f"Error: {info.last_err.mean().item()}"
)

- if backward_mode == BackwardMode.FULL:
+ if backward_mode in [BackwardMode.FULL, BackwardMode.DLM]:
info = self._optimize_loop(
start_iter=0,
num_iter=self.params.max_iterations,
111 changes: 103 additions & 8 deletions theseus/theseus_layer.py
@@ -5,11 +5,17 @@

from typing import Any, Dict, Optional, Tuple

from functools import partial

import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

from theseus import Variable, GaussNewton
from theseus.core.cost_function import AutoDiffCostFunction
from theseus.optimizer import Optimizer, OptimizerInfo
from theseus.optimizer.linear import LinearSolver
from theseus.optimizer.nonlinear import BackwardMode


class TheseusLayer(nn.Module):
@@ -32,15 +38,19 @@ def forward(
"The objective was modified after the layer's construction, which is "
"currently not supported."
)
self.objective.update(input_data)
optimizer_kwargs = optimizer_kwargs or {}
info = self.optimizer.optimize(**optimizer_kwargs)
values = dict(
[
(var_name, var.data)
for var_name, var in self.objective.optim_vars.items()
]
)
backward_mode = optimizer_kwargs.get("backward_mode", None)
DLM_epsilon = optimizer_kwargs.get("DLM_epsilon", 1e-2)
if backward_mode == BackwardMode.DLM:
# TODO: instantiate self.bwd_objective here.
names = set(self.objective.aux_vars.keys()).intersection(input_data.keys())
tensors = [input_data[n] for n in names]
*vars, info = TheseusLayerDLMForward.apply(
self.objective, self.optimizer, optimizer_kwargs, input_data, DLM_epsilon, *tensors
)
else:
vars, info = _forward(self.objective, self.optimizer, optimizer_kwargs, input_data)
values = dict(zip(self.objective.optim_vars.keys(), vars))
return values, info

def compute_samples(
@@ -93,3 +103,88 @@ def device(self) -> torch.device:
@property
def dtype(self) -> torch.dtype:
return self.objective.dtype


def _forward(objective, optimizer, optimizer_kwargs, input_data):
objective.update(input_data)
info = optimizer.optimize(**optimizer_kwargs)
vars = [var.data for var in objective.optim_vars.values()]
return vars, info


class TheseusLayerDLMForward(torch.autograd.Function):
"""
Functionally the same as the forward method in a TheseusLayer,
but computes the direct loss minimization (DLM) gradient in the backward pass.
"""

@staticmethod
def forward(ctx, objective, optimizer, optimizer_kwargs, input_data, epsilon, *params):
optim_vars, info = _forward(objective, optimizer, optimizer_kwargs, input_data)

ctx.input_data = input_data.copy()
ctx.objective = objective
ctx.epsilon = epsilon

# Ideally we compute this in the backward function, but if we try to do that,
# it ends up in an infinite loop because it depends on the outputs of this function.
with torch.enable_grad():
grad_sol = torch.autograd.grad(objective.error_squared_norm().sum(), params, retain_graph=True)

ctx.save_for_backward(*params, *grad_sol, *optim_vars)
ctx.n_params = len(params)
return (*optim_vars, info)


@staticmethod
@once_differentiable
def backward(ctx, *grad_outputs):
saved_tensors = ctx.saved_tensors
params = saved_tensors[:ctx.n_params]
grad_sol = saved_tensors[ctx.n_params:2 * ctx.n_params]
optim_vars = saved_tensors[2 * ctx.n_params:]
grad_outputs = grad_outputs[:-1]

objective = ctx.objective
epsilon = ctx.epsilon

# Update the optim vars to their solutions.
input_data = ctx.input_data
values = dict(zip(objective.optim_vars.keys(), optim_vars))
input_data.update(values)

# Construct backward objective.
bwd_objective = objective.copy()

# Can we put all of this into a single cost function?
for i, (name, var) in enumerate(bwd_objective.optim_vars.items()):
grad_var = Variable(grad_outputs[i], name=name + "_grad")
bwd_objective.add(AutoDiffCostFunction(
[var],
partial(_dlm_perturbation, epsilon=epsilon),
1,
aux_vars=[grad_var],
name="DLM_perturbation_" + name,
))

# Solve backward objective.
bwd_objective.update(input_data)
bwd_optimizer = GaussNewton(
bwd_objective,
max_iterations=1,
step_size=1.0,
)
bwd_optimizer.optimize()

# Compute gradients.
with torch.enable_grad():
grad_perturbed = torch.autograd.grad(bwd_objective.error_squared_norm().sum(), params, retain_graph=True)

grads = [(gs - gp) / epsilon for gs, gp in zip(grad_sol, grad_perturbed)]
return (None, None, None, None, None, *grads)


def _dlm_perturbation(optim_vars, aux_vars, epsilon):
v = optim_vars[0]
g = aux_vars[0]
return epsilon * v.data - 0.5 * g.data
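
For reference, a sketch of the quantity this backward pass appears to approximate, writing f(x, \theta) for objective.error_squared_norm(), \theta for the differentiable input tensors (params), x^* for the solution returned by the forward optimizer, g = \partial L / \partial x^* for the incoming gradient, and \epsilon for DLM_epsilon:

\frac{\partial L}{\partial \theta} \approx \frac{1}{\epsilon}\left[\nabla_\theta f(x^*, \theta) - \nabla_\theta f(x^*_\epsilon, \theta)\right],
\qquad
x^*_\epsilon \approx \operatorname*{arg\,min}_x \; f(x, \theta) - \epsilon \langle g, x \rangle.

The residual added by _dlm_perturbation expands as \|\epsilon x - g/2\|^2 = \epsilon^2\|x\|^2 - \epsilon\langle g, x\rangle + \tfrac{1}{4}\|g\|^2, i.e. the linear tilt -\epsilon\langle g, x\rangle plus an x-independent constant and an O(\epsilon^2) term, and the single Gauss-Newton step started from x^* serves as the approximation of x^*_\epsilon.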