Skip to content

Commit

Permalink
Fix broken backward graph when using LieTensor.exp() (#631)
Browse files Browse the repository at this point in the history
* [bug-fix] create_lie_tensor used in LieTensor.exp() was breaking compute graph.

* Fix broken unit tests when cuda is available.

* Add unit test for exp map backward.
  • Loading branch information
luisenp authored Dec 22, 2023
1 parent c68c7f5 commit 52d02cb
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 11 deletions.
31 changes: 21 additions & 10 deletions tests/torchlie_tests/test_lie_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

@pytest.fixture
def rng():
rng_ = torch.Generator()
rng_ = torch.Generator(device="cuda:0" if torch.cuda.is_available() else "cpu")
rng_.manual_seed(0)
return rng_

Expand Down Expand Up @@ -124,14 +124,25 @@ def _to_torch(t):


def test_backward_works():
# Run optimization to check that the compute graph is not broken
# Runs optimization to check that the compute graph is not broken
def _check(opt_tensor, target_tensor, tensor_fn):
opt = torch.optim.Adam([opt_tensor])
losses = []
for i in range(2):
opt.zero_grad()
d = tensor_fn(opt_tensor).local(target_tensor)
loss = torch.sum(d**2)
losses.append(loss.detach().clone())
loss.backward()
opt.step()
assert not losses[0].allclose(losses[-1])

# Check local op from a random tensor
g1 = lie.SE3.rand(1, requires_grad=True)
g2 = lie.SE3.rand(1)
opt = torch.optim.Adam([g1], lr=0.1)
for i in range(10):
opt.zero_grad()
d = g1.local(g2)
loss = torch.sum(d**2)
loss.backward()
opt.step()
print(f"Iter {i}. Loss: {loss.item(): .3f}")
_check(g1, g2, lambda x: x)

# Check local op from exp map
vec = torch.randn((2, 6), requires_grad=True)
eye = lie.SE3.identity(2)
_check(vec, eye, lambda x: lie.SE3.exp(x))
2 changes: 1 addition & 1 deletion torchlie/torchlie/lie_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def fn(tensor: torch.Tensor) -> LieTensor:
SO3.randn = _build_random_fn("randn", SO3)
SO3.identity = _build_identity_fn(SO3)
SO3._call_impl = _build_call_impl(SO3)
SE3._create_lie_tensor = SO3._create_lie_tensor = LieTensor
SE3._create_lie_tensor = SO3._create_lie_tensor = from_tensor


def log(group: LieTensor) -> torch.Tensor:
Expand Down

0 comments on commit 52d02cb

Please sign in to comment.