
Commit c3300ff

Added utility to automatically check jacobians of a given cost function. (#465)

* Added utility to automatically check jacobians of a given cost function.

* Added unit test for jacobians_check.

* Added TheseusLayer.verify_jacobians() method.

* Added verify_jacobians() call to test_theseus_layer.
luisenp authored Mar 8, 2023
1 parent 16eb5b2 commit c3300ff
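
A minimal sketch of the new utility in action (the setup below is illustrative, built only from names that appear in the diffs that follow):

    import theseus as th
    import theseus.utils as thutils

    cf = th.Difference(th.SE3(name="x"), th.SE3(name="x_target"), th.ScaleCostWeight(0.5))
    thutils.check_jacobians(cf, num_checks=5)  # raises RuntimeError on a mismatch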
Showing 5 changed files with 90 additions and 3 deletions.
10 changes: 7 additions & 3 deletions tests/test_theseus_layer.py
@@ -196,13 +196,15 @@ def _run_optimizer_test(
     force_vectorization=False,
     max_iterations=10,
     lr=0.075,
+    loss_ratio_target=0.01,
 ):
     device = "cuda:0" if torch.cuda.is_available() else "cpu"
     print(f"_run_test_for: {device}")
     print(
         f"testing for optimizer {nonlinear_optimizer_cls.__name__}, "
         f"cost weight modeled as {cost_weight_model}, "
-        f"linear solver {linear_solver_cls.__name__}"
+        f"linear solver {linear_solver_cls.__name__} "
+        f"learning method {learning_method}"
     )

     rng = torch.Generator(device=device)
@@ -286,6 +288,7 @@ def cost_weight_fn():
         max_iterations=max_iterations,
     )
     layer_to_learn.to(device)
+    layer_to_learn.verify_jacobians()

     # Check the initial solution quality to see how much the loss improves later

@@ -372,10 +375,11 @@ def cost_weight_fn():
             loss = mse_loss

         loss.backward()
-        print("Loss: ", loss.item())
         optimizer.step()

-        if mse_loss.item() / loss0 < 1e-2:
+        loss_ratio = mse_loss.item() / loss0
+        print("Loss: ", mse_loss.item(), ". Loss ratio: ", loss_ratio)
+        if loss_ratio < loss_ratio_target:
             solved = True
             break
     assert solved
14 changes: 14 additions & 0 deletions tests/utils/test_utils.py
@@ -9,6 +9,7 @@
 import torch
 import torch.nn as nn

+import theseus as th
 import theseus.utils as thutils

@@ -154,3 +155,16 @@ def test_sparse_mv_cpu(batch_size, num_rows, num_cols, fill):
 @pytest.mark.parametrize("fill", [0.1, 0.9])
 def test_sparse_mv_cuda(batch_size, num_rows, num_cols, fill):
     _check_sparse_mv_and_mtv(batch_size, num_rows, num_cols, fill, "cuda:0")
+
+
+def test_jacobians_check():
+    se3s = [th.SE3() for _ in range(3)]
+    w = th.ScaleCostWeight(0.5)
+    cf = th.Difference(se3s[0], se3s[1], w)
+    thutils.check_jacobians(cf, 1)
+
+    cf = th.Between(se3s[0], se3s[1], se3s[2], w)
+    thutils.check_jacobians(cf, 1)
+
+    cf = th.eb.DoubleIntegrator(se3s[0], th.Vector(6), se3s[1], th.Vector(6), 1.0, w)
+    thutils.check_jacobians(cf, 1)
13 changes: 13 additions & 0 deletions theseus/theseus_layer.py
@@ -23,6 +23,7 @@
 from theseus.optimizer import Optimizer, OptimizerInfo
 from theseus.optimizer.linear import LinearSolver
 from theseus.optimizer.nonlinear import BackwardMode, GaussNewton
+from theseus.utils import check_jacobians


 class TheseusLayer(nn.Module):
@@ -146,6 +147,18 @@ def device(self) -> DeviceType:
     def dtype(self) -> torch.dtype:
         return self.objective.dtype

+    def verify_jacobians(self, num_checks: int = 1, tol: float = 1.0e-3):
+        success = True
+        for cf in self.objective.cost_functions.values():
+            try:
+                check_jacobians(cf, num_checks=num_checks, tol=tol)
+            except RuntimeError as e:
+                print(f"Jacobians check for cost function named {cf.name} failed.")
+                print(e)
+                success = False
+        if success:
+            print("All jacobians checks were successful!")
+

 def _forward(
     objective: Objective,
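
A short usage sketch for the new method (hypothetical setup; unlike check_jacobians, verify_jacobians prints a per-cost-function report instead of raising):

    import torch
    import theseus as th

    objective = th.Objective()
    x = th.Vector(2, name="x")
    target = th.Vector(tensor=torch.zeros(1, 2), name="target")
    objective.add(th.Difference(x, target, th.ScaleCostWeight(1.0)))
    layer = th.TheseusLayer(th.GaussNewton(objective))
    layer.verify_jacobians(num_checks=5, tol=1.0e-3)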
1 change: 1 addition & 0 deletions theseus/utils/__init__.py
@@ -15,6 +15,7 @@
 from .utils import (
     Timer,
     build_mlp,
+    check_jacobians,
     gather_from_rows_cols,
     numeric_grad,
     numeric_jacobian,
55 changes: 55 additions & 0 deletions theseus/utils/utils.py
@@ -130,6 +130,61 @@ def df(x: np.ndarray):
     return df

+
+# Updates the given variable in place with a random tensor of the same shape as the original.
+def _rand_fill_(v: th.Variable, batch_size: int):
+    if isinstance(v, (th.SE2, th.SO2, th.SE3, th.SO3)):
+        v.update(v.rand(batch_size, dtype=v.dtype, device=v.device).tensor)
+    else:
+        v.update(
+            torch.rand((batch_size,) + v.shape[1:], dtype=v.dtype, device=v.device)
+        )
+
+
+# Automatically checks the jacobians of the given cost function a number of times.
+#
+# Computes the manifold jacobians of the given cost function with respect to all
+# of its optimization variables, evaluated at randomly sampled values of the
+# optimization and auxiliary variables, and compares them against jacobians
+# computed by torch autograd. By default, only one check is run, but more can be
+# requested, with one set of sampled values per check. The jacobians are compared
+# entrywise by maximum absolute difference, at the specified tolerance.
+@torch.no_grad()
+def check_jacobians(cf: th.CostFunction, num_checks: int = 1, tol: float = 1.0e-3):
+    from theseus.core.cost_function import _tmp_tensors
+
+    optim_vars: List[th.Manifold] = list(cf.optim_vars)
+    aux_vars = list(cf.aux_vars)
+
+    def autograd_fn(*optim_var_tensors):
+        for v, t in zip(optim_vars, optim_var_tensors):
+            v.update(t)
+        return cf.error()
+
+    with _tmp_tensors(optim_vars), _tmp_tensors(aux_vars):
+        for _ in range(num_checks):
+            for v in optim_vars + aux_vars:
+                _rand_fill_(v, 1)
+
+            autograd_jac = torch.autograd.functional.jacobian(
+                autograd_fn, tuple(v.tensor for v in optim_vars)
+            )
+            jac, _ = cf.jacobians()
+            for idx, v in enumerate(optim_vars):
+                j1 = jac[idx][0]
+                j2 = autograd_jac[idx]
+                # In some "unfriendly" cost functions, the error's batch size could
+                # be different from the optim/aux vars' batch size, if they store
+                # tensors that are not exposed as Theseus variables. To avoid
+                # issues, we only check the first element of the batch.
+                j2_sparse = j2[0, :, 0, :]
+                j2_sparse_manifold = v.project(j2_sparse, is_sparse=True)
+                if (j1 - j2_sparse_manifold).abs().max() > tol:
+                    raise RuntimeError(
+                        f"Jacobian for variable {v.name} appears incorrect to the "
+                        "given tolerance."
+                    )
+
+
 # A basic timer utility that adapts to the device. Useful for removing
 # boilerplate code when benchmarking tasks.
 # For CPU it uses time.perf_counter_ns()
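
The main use case for check_jacobians is validating hand-written jacobians in custom cost functions. Below is a sketch of such a check, modeled on the simple vector-difference cost from the Theseus tutorials; VectorDifference and its setup are illustrative, not part of this commit:

    from typing import List, Optional, Tuple

    import torch
    import theseus as th
    import theseus.utils as thutils


    class VectorDifference(th.CostFunction):
        def __init__(
            self,
            var: th.Vector,
            target: th.Vector,
            cost_weight: th.CostWeight,
            name: Optional[str] = None,
        ):
            super().__init__(cost_weight, name=name)
            self.var = var
            self.target = target
            self.register_optim_vars(["var"])
            self.register_aux_vars(["target"])

        def error(self) -> torch.Tensor:
            return self.var.tensor - self.target.tensor

        def jacobians(self) -> Tuple[List[torch.Tensor], torch.Tensor]:
            # The error is linear in var, so its jacobian is a batched identity.
            batch_size = self.var.shape[0]
            jac = torch.eye(self.var.dof(), dtype=self.var.dtype).repeat(batch_size, 1, 1)
            return [jac], self.error()

        def dim(self) -> int:
            return self.var.dof()

        def _copy_impl(self, new_name: Optional[str] = None) -> "VectorDifference":
            return VectorDifference(
                self.var.copy(), self.target.copy(), self.weight.copy(), name=new_name
            )


    cf = VectorDifference(th.Vector(3), th.Vector(3), th.ScaleCostWeight(1.0))
    thutils.check_jacobians(cf, num_checks=10)  # passes; a wrong jacobian would raise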
