Skip to content

Commit

Permalink
Adds support for energy based learning with NLL loss (LEO) (facebookr…
Browse files Browse the repository at this point in the history
…esearch#30)

* add tests for leo with GN/LM optimizers
* add sampler to GN/LM optimizers
* run leo on 2d state estimation, add viz, learning_method options
  • Loading branch information
psodhi authored Dec 27, 2021
1 parent 703965c commit a6f847d
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 13 deletions.
131 changes: 121 additions & 10 deletions examples/state_estimation_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,51 @@
import torch.nn.functional as F

import theseus as th
import matplotlib.pyplot as plt

device = "cpu"
torch.manual_seed(0)
path_length = 50
state_size = 2
batch_size = 4
learning_method = "leo" # "default", "leo"

vis_flag = True
plt.ion()


# --------------------------------------------------- #
# --------------------- Utilities ------------------- #
# --------------------------------------------------- #
def plot_path(optimizer_path, groundtruth_path):
plt.cla()
plt.gca().axis("equal")

plt.xlim(-250, 250)
plt.ylim(-100, 400)

batch_idx = 0
plt.plot(
optimizer_path[batch_idx, :, 0],
optimizer_path[batch_idx, :, 1],
linewidth=2,
linestyle="-",
color="tab:orange",
label="optimizer",
)
plt.plot(
groundtruth_path[batch_idx, :, 0],
groundtruth_path[batch_idx, :, 1],
linewidth=2,
linestyle="-",
color="tab:green",
label="groundtruth",
)

plt.show()
plt.pause(1e-12)


def generate_path_data(
batch_size_,
num_measurements_,
Expand Down Expand Up @@ -120,6 +154,42 @@ def get_path_from_values(batch_size_, values_, path_length_):
return path


def get_values_from_path(path_):
"""
:param path_: tensor of dim batch_size_ x path_length_ x 2
:return: values: dict of (x,y) pos values
"""
[batch_size_, path_length_, dim] = path_.shape
values = {}
for i in range(path_length_):
values[f"pose_{i}"] = path_[:, i, :2]
return values


def get_average_sample_cost(x_samples, cost_weights_model, objective, mode_):
cost_opt = None
n_samples = x_samples.shape[-1]
for sidx in range(0, n_samples):
x_sample_vals = get_values_from_path(
x_samples[:, :, sidx].reshape(x_samples.shape[0], -1, 2)
)
theseus_inputs = run_model(
mode_,
cost_weights_model,
x_sample_vals,
path_length,
print_stuff=False,
)
objective.update(theseus_inputs)
if cost_opt is not None:
cost_opt = cost_opt + torch.sum(objective.error(), dim=1)
else:
cost_opt = torch.sum(objective.error(), dim=1)
cost_opt = cost_opt / n_samples

return cost_opt


# ------------------------------------------------------------- #
# --------------------------- Learning ------------------------ #
# ------------------------------------------------------------- #
Expand All @@ -132,7 +202,7 @@ def run_learning(mode_, path_data_, gps_targets_, measurements_):
def cost_weights_model():
return model_params * torch.ones(1)

model_optimizer = torch.optim.Adam([model_params], lr=3e-2)
model_optimizer = torch.optim.Adam([model_params], lr=5e-2)
else:
cost_weights_model = SimpleNN(state_size, 2, hid_size=100, use_offset=False).to(
device
Expand Down Expand Up @@ -201,14 +271,14 @@ def cost_weights_model():
state_estimator.to(device)

# ## Learning loop
path_tensor = torch.stack(path_data_).permute(1, 0, 2)
best_loss = 1000.0
inner_loop_iters = 3
groundtruth_path = torch.stack(path_data_).permute(1, 0, 2)
best_solution = None
losses = []
for epoch in range(200):
for epoch in range(500):
model_optimizer.zero_grad()

inner_loop_iters = 3
theseus_inputs = get_initial_inputs(gps_targets_)
theseus_inputs = run_model(
mode_,
Expand Down Expand Up @@ -236,21 +306,64 @@ def cost_weights_model():
print_stuff=epoch % 10 == 0 and i == 0,
)

solution_path = get_path_from_values(
optimizer_path = get_path_from_values(
objective.batch_size, theseus_inputs, path_length
)
mse_loss = F.mse_loss(optimizer_path, groundtruth_path)

# LEO (Sodhi et al., https://arxiv.org/abs/2108.02274) is a method to learn
# models end-to-end within second-order optimizers. The main difference is that
# instead of unrolling the optimizer and minimizing the MSE tracking loss,
# it uses a NLL energy-based loss that does not backpropagate through the optimizer.
if learning_method == "leo":
x_samples = state_estimator.compute_samples(
optimizer.linear_solver, n_samples=10, temperature=1.0
) # batch_size x n_vars x n_samples
# When x_samples is None, this defaults to a perceptron loss
# using the mean trajectory solution from the optimizer.
if x_samples is None:
x_opt_dict = {key: val.detach() for key, val in theseus_inputs.items()}
x_samples = get_path_from_values(
objective.batch_size, x_opt_dict, path_length
)
x_samples = x_samples.reshape(x_samples.shape[0], -1).unsqueeze(
-1
) # batch_size x n_vars x 1
cost_opt = get_average_sample_cost(
x_samples, cost_weights_model, objective, mode_
)
x_gt = get_values_from_path(groundtruth_path)
theseus_inputs_gt = run_model(
mode_,
cost_weights_model,
x_gt,
path_length,
print_stuff=False,
)
objective.update(theseus_inputs_gt)
cost_gt = torch.sum(objective.error(), dim=1)
loss = cost_gt - cost_opt
else:
loss = mse_loss

loss = F.mse_loss(solution_path, path_tensor)
loss = torch.mean(loss, dim=0)
loss.backward()
model_optimizer.step()

loss_value = loss.item()
losses.append(loss_value)
if loss_value < best_loss:
best_loss = loss_value
best_solution = solution_path.detach()
best_solution = optimizer_path.detach()

if epoch % 10 == 0:
print("TOTAL LOSS: ", loss.item())
if vis_flag:
plot_path(
optimizer_path.detach().cpu().numpy(),
groundtruth_path.detach().cpu().numpy(),
)
print("Loss: ", loss.item())
print("MSE error: ", mse_loss.item())
print(f" ---------------- END EPOCH {epoch} -------------- ")

return best_solution, losses
Expand All @@ -269,8 +382,6 @@ def cost_weights_model():
measurement_noise = 0.005 * torch.randn(batch_size, 2).view(batch_size, 2)
measurements.append(measurement + measurement_noise)

mlp_solution, mlp_losses = run_learning("mlp", path_data, gps_targets, measurements)
print(" -------------------------------------------------------------- ")
constant_solution, constant_losses = run_learning(
"constant", path_data, gps_targets, measurements
)
3 changes: 3 additions & 0 deletions theseus/optimizer/dense_linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,6 @@ def _linearize_hessian_impl(self):
At = self.A.transpose(1, 2)
self.AtA = At.bmm(self.A)
self.Atb = At.bmm(self.b.unsqueeze(2))

def hessian_approx(self):
return self.AtA
3 changes: 3 additions & 0 deletions theseus/optimizer/linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,6 @@ def linearize(self):
"Attempted to linearize an objective with an incomplete variable order."
)
self._linearize_hessian_impl()

def hessian_approx(self):
raise NotImplementedError
89 changes: 86 additions & 3 deletions theseus/tests/test_theseus_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,26 @@ def error_fn(optim_vars, aux_vars):
return theseus_layer


def get_average_sample_cost(
x_samples, layer_to_learn, cost_weight_param_name, cost_weight_fn
):
cost_opt = None
n_samples = x_samples.shape[-1]
for sidx in range(0, n_samples):
input_values_opt = {
"coefficients": x_samples[:, :, sidx],
cost_weight_param_name: cost_weight_fn(),
}
layer_to_learn.objective.update(input_values_opt)
if cost_opt is not None:
cost_opt = cost_opt + torch.sum(layer_to_learn.objective.error(), dim=1)
else:
cost_opt = torch.sum(layer_to_learn.objective.error(), dim=1)
cost_opt = cost_opt / n_samples

return cost_opt


def test_layer_solver_constructor():
dummy = torch.ones(1, 1)
for linear_solver_cls in [th.LUDenseSolver, th.CholeskyDenseSolver]:
Expand All @@ -154,6 +174,7 @@ def _run_optimizer_test(
cost_weight_model,
use_learnable_error=False,
verbose=True,
learning_method="default",
):
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"_run_test_for: {device}")
Expand Down Expand Up @@ -280,11 +301,45 @@ def cost_weight_fn():
| (info.status == th.NonlinearOptimizerStatus.FAIL)
).all()

loss = F.mse_loss(pred_vars["coefficients"], target_vars["coefficients"])
mse_loss = F.mse_loss(pred_vars["coefficients"], target_vars["coefficients"])

if learning_method == "leo":
# groundtruth cost
x_gt = target_vars["coefficients"]
input_values_gt = {
"coefficients": x_gt,
cost_weight_param_name: cost_weight_fn(),
}
layer_to_learn.objective.update(input_values_gt)
cost_gt = torch.sum(layer_to_learn.objective.error(), dim=1)

# optimizer cost
x_opt = pred_vars["coefficients"].detach()
x_samples = layer_to_learn.compute_samples(
layer_to_learn.optimizer.linear_solver, n_samples=10, temperature=1.0
) # batch_size x n_vars x n_samples
if x_samples is None: # use mean solution
x_samples = x_opt.reshape(x_opt.shape[0], -1).unsqueeze(
-1
) # batch_size x n_vars x n_samples
cost_opt = get_average_sample_cost(
x_samples, layer_to_learn, cost_weight_param_name, cost_weight_fn
)

# loss value
l2_reg = F.mse_loss(
cost_weight_fn(), torch.zeros((1, num_points), device=device)
)
loss = (cost_gt - cost_opt) + 10.0 * l2_reg
loss = torch.mean(loss, dim=0)
else:
loss = mse_loss

loss.backward()
print(i, loss.item(), loss.item() / loss0)
optimizer.step()
if loss.item() / loss0 < 5e-3:

print(i, mse_loss.item())
if mse_loss.item() / loss0 < 5e-3:
solved = True
break
assert solved
Expand Down Expand Up @@ -340,6 +395,34 @@ def test_backward_levenberg_marquardt_choleskysparse():
)


def test_backward_gauss_newton_leo():
for use_learnable_error in [True, False]:
for linear_solver_cls in [th.CholeskyDenseSolver, th.LUDenseSolver]:
for cost_weight_model in ["mlp"]:
_run_optimizer_test(
th.GaussNewton,
linear_solver_cls,
{},
cost_weight_model,
use_learnable_error=use_learnable_error,
learning_method="leo",
)


def test_backward_levenberg_marquardt_leo():
for use_learnable_error in [True, False]:
for linear_solver_cls in [th.CholeskyDenseSolver, th.LUDenseSolver]:
for cost_weight_model in ["mlp"]:
_run_optimizer_test(
th.LevenbergMarquardt,
linear_solver_cls,
{"damping": 0.01},
cost_weight_model,
use_learnable_error=use_learnable_error,
learning_method="leo",
)


def test_send_to_device():
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"test_send_to_device: {device}")
Expand Down
39 changes: 39 additions & 0 deletions theseus/theseus_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch.nn as nn

from theseus.optimizer import Optimizer, OptimizerInfo
from theseus.optimizer.linear import LinearSolver


class TheseusLayer(nn.Module):
Expand Down Expand Up @@ -45,6 +46,44 @@ def forward(
)
return values, info

def compute_samples(
self,
linear_solver: LinearSolver = None,
n_samples: int = 10,
temperature: float = 1.0,
) -> torch.Tensor:
# When samples are not available, return None. This makes the outer learning loop default
# to a perceptron loss using the mean trajectory solution from the optimizer.
if linear_solver is None:
return None

# Sampling from multivariate normal using a Cholesky decomposition of AtA,
# http://www.statsathome.com/2018/10/19/sampling-from-multivariate-normal-precision-and-covariance-parameterizations/
delta = linear_solver.solve()
AtA = linear_solver.linearization.hessian_approx() / temperature
sqrt_AtA = torch.linalg.cholesky(AtA).permute(0, 2, 1)

batch_size, n_vars = delta.shape
y = torch.normal(
mean=torch.zeros((n_vars, n_samples), device=delta.device),
std=torch.ones((n_vars, n_samples), device=delta.device),
)
delta_samples = (torch.triangular_solve(y, sqrt_AtA).solution) + (
delta.unsqueeze(-1)
).repeat(1, 1, n_samples)

x_samples = torch.zeros((batch_size, n_vars, n_samples), device=delta.device)
for sidx in range(0, n_samples):
var_idx = 0
for var in linear_solver.linearization.ordering:
new_var = var.retract(
delta_samples[:, var_idx : var_idx + var.dof(), sidx]
)
x_samples[:, var_idx : var_idx + var.dof(), sidx] = new_var.data
var_idx = var_idx + var.dof()

return x_samples

# Applies to() with given args to all tensors in the objective
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
Expand Down

0 comments on commit a6f847d

Please sign in to comment.