Objective.to() and TheseusLayer.to() now return self and support "cuda". (#623)

* Changed Objective.to() and TheseusLayer.to() to return self.

* Added support for objective.to('cuda').
luisenp authored Nov 16, 2023
1 parent d42be9a commit 322807e
Showing 11 changed files with 37 additions and 23 deletions.
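Since both methods now return `self`, construction and device/dtype transfer can be chained in one expression, and `Objective.to()` accepts a plain device string such as "cuda". A minimal usage sketch, with variables and the cost function made up purely for illustration:

```python
import torch
import theseus as th

# Toy one-variable objective; the names here are illustrative only.
x = th.Vector(1, name="x")
target = th.Vector(tensor=torch.ones(1, 1), name="target")
objective = th.Objective()
objective.add(th.Difference(x, target, th.ScaleCostWeight(1.0)))

# Objective.to() now returns the objective, so it can be passed inline.
optimizer = th.GaussNewton(objective.to(dtype=torch.float64))

# TheseusLayer.to() also returns self and accepts a device string like "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"
layer = th.TheseusLayer(optimizer).to(device)
```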
3 changes: 1 addition & 2 deletions examples/homography_estimation.py
@@ -361,8 +361,7 @@ def run(
max_iterations=max_iterations,
step_size=step_size,
)
theseus_layer = th.TheseusLayer(inner_optim)
theseus_layer.to(device)
theseus_layer = th.TheseusLayer(inner_optim).to(device)

# Set up outer loop optimization.
outer_optim = torch.optim.Adam(cnn_model.parameters(), lr=outer_lr)
3 changes: 1 addition & 2 deletions examples/pose_graph/pose_graph_benchmark.py
@@ -59,9 +59,8 @@ def main(cfg):
)
objective.add(pose_prior)

objective.to(dtype)
optimizer = th.LevenbergMarquardt(
objective,
objective.to(dtype),
max_iterations=10,
step_size=1,
linearization_cls=th.SparseLinearization,
4 changes: 1 addition & 3 deletions examples/pose_graph/pose_graph_cube.py
@@ -82,16 +82,14 @@ def run(

objective.add(pose_prior_cost)

objective.to(cfg.device)

linear_solver_cls: Type[LinearSolver] = cast(
Type[LinearSolver],
th.LUCudaSparseSolver
if cast(str, cfg.solver_device) == "cuda"
else th.CholmodSparseSolver,
)
optimizer = th.GaussNewton(
objective,
objective.to(cfg.device),
max_iterations=cfg.inner_optim.max_iters,
step_size=cfg.inner_optim.step_size,
abs_err_tolerance=0,
3 changes: 1 addition & 2 deletions examples/pose_graph/pose_graph_synthetic.py
@@ -175,9 +175,8 @@ def run(cfg: omegaconf.OmegaConf):
th, cfg.inner_optim.optimizer_cls
)

objective.to(device)
optimizer = optimizer_cls(
objective,
objective.to(device),
max_iterations=cfg.inner_optim.max_iters,
step_size=cfg.inner_optim.step_size,
linear_solver_cls=getattr(th, cfg.inner_optim.linear_solver_cls),
19 changes: 18 additions & 1 deletion tests/theseus_tests/core/test_objective.py
@@ -548,6 +548,21 @@ def test_to_dtype():
for aux in cf.aux_vars:
assert aux.dtype == dtype


def test_to_device():
    if not torch.cuda.is_available():
        return
    objective, *_ = create_objective_with_mock_cost_functions()
    for device in ["cuda", "cuda:0"]:
        dummy = torch.zeros(1, device=device)
        objective.to(device)
        for _, cf in objective.cost_functions.items():
            for var in cf.optim_vars:
                assert var.device == dummy.device
            for aux in cf.aux_vars:
                assert aux.device == dummy.device


def test_cost_delete_and_add():
x = th.Variable(torch.zeros(2), name="x")
y = th.Variable(torch.zeros(3), name="y")
@@ -557,7 +572,9 @@ def error_fn(optim_vars, aux_vars):

objective = th.Objective()
assert len(objective.aux_vars) == 0
cost_function = th.AutoDiffCostFunction([x], error_fn, 1, aux_vars=[y], cost_weight=th.ScaleCostWeight(1.0))
cost_function = th.AutoDiffCostFunction(
[x], error_fn, 1, aux_vars=[y], cost_weight=th.ScaleCostWeight(1.0)
)

# Add a cost function, erase it, and add it again to make sure we don't have any bugs.
objective.add(cost_function)
3 changes: 1 addition & 2 deletions tests/theseus_tests/core/test_vectorizer.py
@@ -355,8 +355,7 @@ def _solve_fn_for_masked_jacobians(
layer = th.TheseusLayer(
th.LevenbergMarquardt(obj, step_size=0.1, max_iterations=5),
vectorize=vectorize,
)
layer.to(device=device)
).to(device=device)
sol, _ = layer.forward(input_tensors)

# Check that we can backprop through this without errors
3 changes: 1 addition & 2 deletions tests/theseus_tests/test_dlm_perturbation.py
@@ -86,8 +86,7 @@ def test_backward_pass_se3_runs():

objective = th.Objective()
objective.add(th.Difference(var, target, th.ScaleCostWeight(1.0)))
objective.to(dtype=dtype)
optimizer = th.GaussNewton(objective)
optimizer = th.GaussNewton(objective.to(dtype=dtype))
layer = th.TheseusLayer(optimizer)

target_data = torch.nn.Parameter(th.rand_se3(batch_size, dtype=dtype).tensor)
3 changes: 1 addition & 2 deletions tests/theseus_tests/test_theseus_layer.py
@@ -538,8 +538,7 @@ def test_pass_optimizer_kwargs():
ys,
nonlinear_optimizer_cls=th.GaussNewton,
linear_solver_cls=th.CholmodSparseSolver,
)
layer.to("cpu")
).to("cpu")
input_values = {"coefficients": torch.ones(batch_size, 2) * 0.5}
for tbs in [True, False]:
_, info = layer.forward(
11 changes: 7 additions & 4 deletions theseus/core/objective.py
@@ -153,6 +153,8 @@ def __init__(

        self._allow_mixed_optim_aux_vars = __allow_mixed_optim_aux_vars__

        self.__tensor_ref = torch.empty(1, device=self.device, dtype=self.dtype)

    def _add_function_variables(
        self,
        function: TheseusFunction,
@@ -841,15 +843,16 @@ def _get_jacobians_iter(self) -> Iterable:
        # No vectorization is used, just serve from cost functions
        return iter(cf for cf in self.cost_functions.values())

    def to(self, *args, **kwargs):
    def to(self, *args, **kwargs) -> "Objective":
        """Applies torch.Tensor.to() to all cost functions in the objective."""
        for cost_function in self.cost_functions.values():
            cost_function.to(*args, **kwargs)
        device, dtype, *_ = torch._C._nn._parse_to(*args, **kwargs)
        self.device = device or self.device
        self.dtype = dtype or self.dtype
        self.__tensor_ref = self.__tensor_ref.to(*args, **kwargs)
        self.device = self.__tensor_ref.device
        self.dtype = self.__tensor_ref.dtype
        if self._vectorization_to is not None:
            self._vectorization_to(*args, **kwargs)
        return self

@staticmethod
def _retract_base(
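The new implementation keeps a one-element reference tensor (created in `__init__` as `self.__tensor_ref`), moves it alongside the cost functions, and reads the objective's `device` and `dtype` back from it. Letting `torch.Tensor.to()` resolve the arguments is what makes a bare string like "cuda" work, and it avoids the previous call into the private `torch._C._nn._parse_to` helper. A standalone sketch of the same pattern outside Theseus; the class and attribute names here are made up:

```python
import torch

class TensorContainer:
    """Toy container that tracks device/dtype via a 1-element reference tensor."""

    def __init__(self, tensors):
        self.tensors = list(tensors)
        self._ref = torch.empty(1)  # mirrors the role of Objective.__tensor_ref

    def to(self, *args, **kwargs) -> "TensorContainer":
        self.tensors = [t.to(*args, **kwargs) for t in self.tensors]
        # torch resolves strings like "cuda" or "cuda:0" to a concrete device.
        self._ref = self._ref.to(*args, **kwargs)
        return self

    @property
    def device(self) -> torch.device:
        return self._ref.device

    @property
    def dtype(self) -> torch.dtype:
        return self._ref.dtype

container = TensorContainer([torch.zeros(3)]).to(dtype=torch.float64)
assert container.dtype == torch.float64
```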
3 changes: 2 additions & 1 deletion theseus/theseus_layer.py
@@ -135,9 +135,10 @@ def compute_samples(
        return x_samples

    # Applies to() with given args to all tensors in the objective
    def to(self, *args, **kwargs):
    def to(self, *args, **kwargs) -> "TheseusLayer":
        super().to(*args, **kwargs)
        self.objective.to(*args, **kwargs)
        return self

@property
def device(self) -> DeviceType:
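TheseusLayer is an `nn.Module`, but the objective's variables are plain tensors rather than registered parameters or buffers, so the default `nn.Module.to()` would not move them; the override forwards the call to the objective and now also returns `self`, matching the usual `module.to(device)` idiom. A minimal sketch of that wrapper pattern, with illustrative names that are not the Theseus API:

```python
import torch
import torch.nn as nn

class TensorHolder:
    # Plain object (not an nn.Module) owning tensors; stands in for the objective.
    def __init__(self):
        self.data = torch.zeros(3)

    def to(self, *args, **kwargs) -> "TensorHolder":
        self.data = self.data.to(*args, **kwargs)
        return self

class WrapperLayer(nn.Module):
    def __init__(self, holder: TensorHolder):
        super().__init__()
        self.holder = holder

    def to(self, *args, **kwargs) -> "WrapperLayer":
        super().to(*args, **kwargs)      # moves registered params/buffers only
        self.holder.to(*args, **kwargs)  # moves the held plain tensors as well
        return self                      # enables WrapperLayer(...).to(device)

layer = WrapperLayer(TensorHolder()).to(dtype=torch.float64)
assert layer.holder.data.dtype == torch.float64
```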
@@ -204,8 +204,9 @@ def __init__(
abs_err_tolerance=0 if force_max_iters else 1e-10,
rel_err_tolerance=0 if force_max_iters else 1e-8,
)
self.theseus_layer = th.TheseusLayer(nl_optimizer)
self.theseus_layer.to(device=device, dtype=torch.double)
self.theseus_layer = th.TheseusLayer(nl_optimizer).to(
device=device, dtype=torch.double
)

self.forward = self.theseus_layer.forward

