|
5 | 5 | import numpy as np
|
6 | 6 | import torch
|
7 | 7 | from torch.autograd import grad
|
| 8 | +from torch.distributions import kl_divergence |
8 | 9 |
|
9 | 10 | from stochman.curves import BasicCurve, CubicSpline
|
10 | 11 | from stochman.geodesic import geodesic_minimizing_energy, shooting_geodesic
|
11 |
| -from stochman.utilities import squared_manifold_distance |
| 12 | +from stochman.utilities import squared_manifold_distance, kl_by_sampling |
12 | 13 |
|
13 | 14 |
|
14 | 15 | class Manifold(ABC):
|
@@ -661,3 +662,36 @@ def geodesic_system(self, c, dc):
|
661 | 662 |
|
662 | 663 | ddc_tensor = torch.cat(ddc, dim=1).t() # NxD
|
663 | 664 | return ddc_tensor
|
| 665 | + |
| 666 | + |
class StochasticManifold(Manifold):
    """
    A class for computing Stochastic Manifolds and defining
    a metric in a latent space using the Fisher-Rao metric.

    The metric is pulled back through a stochastic decoder: curve energy is
    measured by the KL divergence between the decoded distributions of
    consecutive points along the curve, weighted by the squared latent
    segment lengths.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Class constructor:

        Arguments:
        - model: a torch module that implements a `decode(z: Tensor) -> Distribution`.
        """
        super().__init__()

        # `decode` is the only part of the model's interface we rely on.
        assert hasattr(model, "decode")
        self.model = model

    def curve_energy(self, curve: BasicCurve) -> torch.Tensor:
        """
        Compute the (discretized) energy of `curve` under the pulled-back
        Fisher-Rao metric.

        Arguments:
        - curve: an Nx(latent dim) tensor of points along the curve.

        Returns:
        - a scalar tensor: sum over segments of KL(p(x|z_i) || p(x|z_{i+1}))
          weighted by the squared Euclidean length of each latent segment.
        """
        # Squared Euclidean length of each latent segment: (N-1)x1.
        dt = (curve[:-1] - curve[1:]).pow(2).sum(dim=-1, keepdim=True)  # (N-1)x1
        # Decode consecutive curve *points* (not the segment lengths `dt`,
        # which would yield mismatched shapes and a meaningless metric).
        dist1 = self.model.decode(curve[:-1])
        dist2 = self.model.decode(curve[1:])

        try:
            kl = kl_divergence(dist1, dist2)
        except NotImplementedError:
            # torch raises NotImplementedError when no closed-form KL is
            # registered for this distribution pair; fall back to a
            # Monte-Carlo estimate.
            kl = kl_by_sampling(dist1, dist2)

        return (kl * dt).sum()
0 commit comments