Changed TheseusLayer.forward() to receive optimizer_kwargs as a single dict (#45)

* [refactor] Changed TheseusLayer so that optimizer_kwargs are passed as a single dict.

* Updated all tutorials to use optimizer_kwargs dict in forward().

* Updated examples to use optimizer_kwargs dict in forward().

* Added a test to check that TheseusLayer.forward() does not accept an aux_vars keyword argument.
luisenp authored Jan 24, 2022
1 parent 6d89db7 commit 58d3c6e
Showing 15 changed files with 223 additions and 215 deletions.
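
To illustrate the API change, here is a minimal sketch of the new calling convention. It assumes a th.TheseusLayer named theseus_optim and an input dictionary theseus_inputs have already been constructed (for example, as in examples/backward_modes.py below); the keyword names are taken directly from the diffs in this commit.

import theseus as th

# Assumed to already exist, e.g. built as in examples/backward_modes.py:
#   theseus_optim: a th.TheseusLayer wrapping a nonlinear optimizer
#   theseus_inputs: a dict mapping variable names to batched tensors
#
# Before this commit, optimizer options were passed to forward() as separate
# keyword arguments:
#   theseus_optim.forward(
#       theseus_inputs, track_best_solution=True, verbose=False,
#       backward_mode=th.BackwardMode.FULL,
#   )
# After this commit, they are collected into a single optimizer_kwargs dict:
updated_inputs, info = theseus_optim.forward(
    theseus_inputs,
    optimizer_kwargs={
        "track_best_solution": True,
        "verbose": False,
        "backward_mode": th.BackwardMode.FULL,
    },
)

Any other keyword argument to forward() (e.g. aux_vars=None) is rejected with a TypeError, which the new test_no_layer_kwargs test below checks.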
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -15,7 +15,7 @@ repos:
rev: v0.910
hooks:
- id: mypy
additional_dependencies: [torch==1.9.0, tokenize-rt==3.2.0, types-PyYAML]
additional_dependencies: [torch==1.9.0, tokenize-rt==3.2.0, types-PyYAML, types-mock]
args: [--no-strict-optional, --ignore-missing-imports]
exclude: setup.py

56 changes: 34 additions & 22 deletions examples/backward_modes.py
@@ -79,9 +79,11 @@ def quad_error_fn(optim_vars, aux_vars):
theseus_optim = th.TheseusLayer(optimizer)
updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.FULL,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.FULL,
},
)

# The quadratic \hat y is now fit and we can also use Theseus
@@ -98,9 +100,11 @@ def quad_error_fn(optim_vars, aux_vars):
# forward again and changing the backward_mode flag.
updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.IMPLICIT,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.IMPLICIT,
},
)

da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[0].squeeze()
@@ -110,10 +114,12 @@ def quad_error_fn(optim_vars, aux_vars):
# We can also use truncated unrolling to compute the derivative:
updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.TRUNCATED,
backward_num_iterations=5,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.TRUNCATED,
"backward_num_iterations": 5,
},
)

da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[0].squeeze()
@@ -127,8 +133,8 @@ def fit_x(data_x_np):
theseus_inputs["x"] = (
torch.from_numpy(data_x_np).float().clone().requires_grad_().unsqueeze(0)
)
updated_inputs, info = theseus_optim.forward(
theseus_inputs, track_best_solution=True, verbose=False
updated_inputs, _ = theseus_optim.forward(
theseus_inputs, optimizer_kwargs={"track_best_solution": True, "verbose": False}
)
return updated_inputs["a"].item()

@@ -150,9 +156,11 @@ def fit_x(data_x_np):
start = time.time()
updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.FULL,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.FULL,
},
)
times["fwd"].append(time.time() - start)

@@ -164,9 +172,11 @@ def fit_x(data_x_np):

updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.IMPLICIT,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.IMPLICIT,
},
)
start = time.time()
da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[
@@ -176,10 +186,12 @@ def fit_x(data_x_np):

updated_inputs, info = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.TRUNCATED,
backward_num_iterations=5,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.TRUNCATED,
"backward_num_iterations": 5,
},
)
start = time.time()
da_dx = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[
10 changes: 7 additions & 3 deletions examples/motion_planning_2d.py
@@ -149,9 +149,13 @@ def run_learning_loop(cfg):

_, info = motion_planner.layer.forward(
planner_inputs,
track_best_solution=True,
verbose=cfg.verbose,
**cfg.optim_params.kwargs,
optimizer_kwargs={
**{
"track_best_solution": True,
"verbose": cfg.verbose,
},
**cfg.optim_params.kwargs,
},
)
if cfg.do_learning and cfg.include_imitation_loss:
solution_trajectory = motion_planner.get_trajectory()
8 changes: 5 additions & 3 deletions examples/state_estimation_2d.py
@@ -293,10 +293,12 @@ def cost_weights_model():
print("Initial error:", objective.error_squared_norm().mean().item())

for i in range(inner_loop_iters):
theseus_inputs, info = state_estimator.forward(
theseus_inputs, _ = state_estimator.forward(
theseus_inputs,
track_best_solution=True,
verbose=epoch % 10 == 0,
optimizer_kwargs={
"track_best_solution": True,
"verbose": epoch % 10 == 0,
},
)
theseus_inputs = run_model(
mode_,
4 changes: 3 additions & 1 deletion examples/tactile_pose_estimation.py
@@ -325,7 +325,9 @@ def run_learning_loop(cfg):
(sdf_tensor.data).repeat(batch_size, 1, 1).to(device)
)

theseus_inputs, _ = theseus_layer.forward(theseus_inputs, verbose=True)
theseus_inputs, _ = theseus_layer.forward(
theseus_inputs, optimizer_kwargs={"verbose": True}
)

obj_poses_opt = theg.get_tactile_poses_from_values(
batch_size=batch_size,
1 change: 1 addition & 0 deletions requirements/dev.txt
@@ -5,3 +5,4 @@ nox==2020.8.22
pre-commit>=2.9.2
isort>=5.6.4
types-PyYAML==5.4.3
types-mock>=4.0.8
3 changes: 2 additions & 1 deletion requirements/main.txt
@@ -4,4 +4,5 @@ scikit-sparse>=0.4.5
# torch>=1.7.1 will do separate install instructions for now (CUDA dependent)
pytest>=6.2.1
numdifftools>=0.9.40
pybind11>=2.7.1
pybind11>=2.7.1
mock>=4.0.3
1 change: 1 addition & 0 deletions theseus/core/objective.py
@@ -422,6 +422,7 @@ def _get_batch_size(batch_sizes: Sequence[int]) -> int:
return max_bs
raise ValueError("Provided data tensors must be broadcastable.")

input_data = input_data or {}
for var_name, data in input_data.items():
if data.ndim < 2:
raise ValueError(
40 changes: 22 additions & 18 deletions theseus/optimizer/nonlinear/tests/test_backwards.py
@@ -69,8 +69,9 @@ def fit_x(data_x_np):
theseus_inputs["x"] = (
torch.from_numpy(data_x_np).float().clone().requires_grad_().unsqueeze(0)
)
updated_inputs, info = theseus_optim.forward(
theseus_inputs, track_best_solution=True, verbose=False
updated_inputs, _ = theseus_optim.forward(
theseus_inputs,
optimizer_kwargs={"track_best_solution": True, "verbose": False},
)
return updated_inputs["a"].item()

@@ -79,39 +80,42 @@ def fit_x(data_x_np):
da_dx_numeric = torch.from_numpy(dfit_x(data_x_np)).float()

theseus_inputs["x"] = data_x
updated_inputs, info = theseus_optim.forward(
updated_inputs, _ = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.FULL,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.FULL,
},
)
da_dx_full = torch.autograd.grad(updated_inputs["a"], data_x, retain_graph=True)[
0
].squeeze()
assert torch.allclose(da_dx_numeric, da_dx_full, atol=1e-3)

updated_inputs, info = theseus_optim.forward(
updated_inputs, _ = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.IMPLICIT,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.IMPLICIT,
},
)
da_dx_implicit = torch.autograd.grad(
updated_inputs["a"], data_x, retain_graph=True
)[0].squeeze()
assert torch.allclose(da_dx_numeric, da_dx_implicit, atol=1e-4)

updated_inputs, info = theseus_optim.forward(
updated_inputs, _ = theseus_optim.forward(
theseus_inputs,
track_best_solution=True,
verbose=False,
backward_mode=th.BackwardMode.TRUNCATED,
backward_num_iterations=5,
optimizer_kwargs={
"track_best_solution": True,
"verbose": False,
"backward_mode": th.BackwardMode.TRUNCATED,
"backward_num_iterations": 5,
},
)
da_dx_truncated = torch.autograd.grad(
updated_inputs["a"], data_x, retain_graph=True
)[0].squeeze()
assert torch.allclose(da_dx_numeric, da_dx_truncated, atol=1e-4)


test_backwards()
91 changes: 84 additions & 7 deletions theseus/tests/test_theseus_layer.py
@@ -5,6 +5,7 @@

import math

import mock
import pytest # noqa: F401
import torch
import torch.nn as nn
@@ -214,7 +215,7 @@ def _run_optimizer_test(
with torch.no_grad():
input_values = {"coefficients": torch.ones(batch_size, 2, device=device) * 0.75}
target_vars, _ = layer_ref.forward(
input_values, verbose=verbose, **optimizer_kwargs
input_values, optimizer_kwargs={**optimizer_kwargs, **{"verbose": verbose}}
)

# Now create another that starts with a random cost weight and use backpropagation to
@@ -275,7 +276,9 @@ def cost_weight_fn():
}

with torch.no_grad():
pred_vars, info = layer_to_learn.forward(input_values, **optimizer_kwargs)
pred_vars, info = layer_to_learn.forward(
input_values, optimizer_kwargs=optimizer_kwargs
)
loss0 = F.mse_loss(
pred_vars["coefficients"], target_vars["coefficients"]
).item()
@@ -294,7 +297,7 @@ def cost_weight_fn():
cost_weight_param_name: cost_weight_fn(),
}
pred_vars, info = layer_to_learn.forward(
input_values, verbose=verbose, **optimizer_kwargs
input_values, optimizer_kwargs={**optimizer_kwargs, **{"verbose": verbose}}
)
assert not (
(info.status == th.NonlinearOptimizerStatus.START)
@@ -433,14 +436,14 @@ def test_send_to_device():
xs = torch.linspace(0, 10, num_points).repeat(batch_size, 1)
ys = model(xs, torch.ones(batch_size, 2))

objective = create_qf_theseus_layer(xs, ys)
layer = create_qf_theseus_layer(xs, ys)
input_values = {"coefficients": torch.ones(batch_size, 2, device=device) * 0.5}
with torch.no_grad():
if device != "cpu":
with pytest.raises(RuntimeError):
objective.forward(input_values)
objective.to(device)
output_values, _ = objective.forward(input_values)
layer.forward(input_values)
layer.to(device)
output_values, _ = layer.forward(input_values)
for k, v in output_values.items():
assert v.device == input_values[k].device

@@ -470,3 +473,77 @@ def _do_check(layer_, optimizer_):
optimizer = th.GaussNewton(objective, th.CholeskyDenseSolver)
objective.erase(cost_functions[0].name)
_do_check(layer, optimizer)


def test_pass_optimizer_kwargs():
    # Create the dataset to fit; model(x) is the true data generation process
batch_size = 16
num_points = 10
xs = torch.linspace(0, 10, num_points).repeat(batch_size, 1)
ys = model(xs, torch.ones(batch_size, 2))

layer = create_qf_theseus_layer(
xs,
ys,
nonlinear_optimizer_cls=th.GaussNewton,
linear_solver_cls=th.CholmodSparseSolver,
)
layer.to("cpu")
input_values = {"coefficients": torch.ones(batch_size, 2) * 0.5}
for tbs in [True, False]:
_, info = layer.forward(
input_values, optimizer_kwargs={"track_best_solution": tbs}
)
if tbs:
assert (
isinstance(info.best_solution, dict)
and "coefficients" in info.best_solution
)
else:
assert info.best_solution is None

# Pass invalid backward mode to trigger exception
with pytest.raises(ValueError):
layer.forward(input_values, optimizer_kwargs={"backward_mode": -1})

    # Now test that compute_delta() args are passed correctly
    # Patch compute_delta() to receive args we control
def _mock_compute_delta(cls, fake_arg=None, **kwargs):
if fake_arg is not None:
raise ValueError
return layer.optimizer.linear_solver.solve()

with mock.patch.object(th.GaussNewton, "compute_delta", _mock_compute_delta):
layer_2 = create_qf_theseus_layer(xs, ys)
layer_2.forward(input_values)
        # If fake_arg is forwarded correctly, the mocked compute_delta will raise ValueError
with pytest.raises(ValueError):
layer_2.forward(input_values, {"fake_arg": True})


def test_no_layer_kwargs():
    # Create the dataset to fit; model(x) is the true data generation process
batch_size = 16
num_points = 10
xs = torch.linspace(0, 10, num_points).repeat(batch_size, 1)
ys = model(xs, torch.ones(batch_size, 2))

layer = create_qf_theseus_layer(
xs,
ys,
nonlinear_optimizer_cls=th.GaussNewton,
linear_solver_cls=th.CholmodSparseSolver,
)
layer.to("cpu")
input_values = {"coefficients": torch.ones(batch_size, 2) * 0.5}

    # Trying a few variations of aux_vars. In general, no kwargs should be accepted
    # beyond input_data and optimizer_kwargs, but I'm not sure how to test for this
with pytest.raises(TypeError):
layer.forward(input_values, aux_vars=None)

with pytest.raises(TypeError):
layer.forward(input_values, aux_variables=None)

with pytest.raises(TypeError):
layer.forward(input_values, auxiliary_vars=None)