diff --git a/tests/torchlie_tests/test_lie_tensor.py b/tests/torchlie_tests/test_lie_tensor.py
index d222aa01..49350947 100644
--- a/tests/torchlie_tests/test_lie_tensor.py
+++ b/tests/torchlie_tests/test_lie_tensor.py
@@ -13,7 +13,7 @@
 
 @pytest.fixture
 def rng():
-    rng_ = torch.Generator()
+    rng_ = torch.Generator(device="cuda:0" if torch.cuda.is_available() else "cpu")
     rng_.manual_seed(0)
     return rng_
 
@@ -124,14 +124,25 @@ def _to_torch(t):
 
 
 def test_backward_works():
-    # Run optimization to check that the compute graph is not broken
+    # Runs optimization to check that the compute graph is not broken
+    def _check(opt_tensor, target_tensor, tensor_fn):
+        opt = torch.optim.Adam([opt_tensor])
+        losses = []
+        for i in range(2):
+            opt.zero_grad()
+            d = tensor_fn(opt_tensor).local(target_tensor)
+            loss = torch.sum(d**2)
+            losses.append(loss.detach().clone())
+            loss.backward()
+            opt.step()
+        assert not losses[0].allclose(losses[-1])
+
+    # Check local op from a random tensor
     g1 = lie.SE3.rand(1, requires_grad=True)
     g2 = lie.SE3.rand(1)
-    opt = torch.optim.Adam([g1], lr=0.1)
-    for i in range(10):
-        opt.zero_grad()
-        d = g1.local(g2)
-        loss = torch.sum(d**2)
-        loss.backward()
-        opt.step()
-        print(f"Iter {i}. Loss: {loss.item(): .3f}")
+    _check(g1, g2, lambda x: x)
+
+    # Check local op from exp map
+    vec = torch.randn((2, 6), requires_grad=True)
+    eye = lie.SE3.identity(2)
+    _check(vec, eye, lambda x: lie.SE3.exp(x))
diff --git a/torchlie/torchlie/lie_tensor.py b/torchlie/torchlie/lie_tensor.py
index dff790be..edf45d42 100644
--- a/torchlie/torchlie/lie_tensor.py
+++ b/torchlie/torchlie/lie_tensor.py
@@ -552,7 +552,7 @@ def fn(tensor: torch.Tensor) -> LieTensor:
 SO3.randn = _build_random_fn("randn", SO3)
 SO3.identity = _build_identity_fn(SO3)
 SO3._call_impl = _build_call_impl(SO3)
-SE3._create_lie_tensor = SO3._create_lie_tensor = LieTensor
+SE3._create_lie_tensor = SO3._create_lie_tensor = from_tensor
 
 
 def log(group: LieTensor) -> torch.Tensor:
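
Note: below is a minimal sketch of the behavior the new exp-map check appears to guard, using only the torchlie API visible in the diff (lie.SE3.exp, lie.SE3.identity, LieTensor.local) and assuming the test file's `lie` alias refers to the torchlie package. The new test suggests that binding _create_lie_tensor to from_tensor (rather than the raw LieTensor constructor) keeps op results attached to the autograd graph, so gradients reach the input tangent vector:

    import torch
    import torchlie as lie  # assumed import alias, matching the test file's `lie`

    # se(3) tangent vectors to differentiate through; shape (batch, 6)
    vec = torch.randn((2, 6), requires_grad=True)
    target = lie.SE3.identity(2)

    # exp map followed by local(); both must stay on the compute graph
    d = lie.SE3.exp(vec).local(target)
    loss = torch.sum(d**2)
    loss.backward()

    # Would fail if _create_lie_tensor rebuilt tensors off-graph
    assert vec.grad is not None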