Skip to content

Commit f19fec5

Browse files
committed
Starts building the StochasticManifold class
1 parent 72946ea commit f19fec5

File tree

3 files changed

+47
-1
lines changed

3 files changed

+47
-1
lines changed

stochman/manifold.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
import numpy as np
66
import torch
77
from torch.autograd import grad
8+
from torch.distributions import kl_divergence
89

910
from stochman.curves import BasicCurve, CubicSpline
1011
from stochman.geodesic import geodesic_minimizing_energy, shooting_geodesic
11-
from stochman.utilities import squared_manifold_distance
12+
from stochman.utilities import squared_manifold_distance, kl_by_sampling
1213

1314

1415
class Manifold(ABC):
@@ -661,3 +662,36 @@ def geodesic_system(self, c, dc):
661662

662663
ddc_tensor = torch.cat(ddc, dim=1).t() # NxD
663664
return ddc_tensor
665+
666+
667+
class StochasticManifold(Manifold):
668+
"""
669+
A class for computing Stochastic Manifolds and defining
670+
a metric in a latent space using the Fisher-Rao metric.
671+
"""
672+
673+
def __init__(self, model: torch.nn.Module) -> None:
674+
"""
675+
Class constructor:
676+
677+
Arguments:
678+
- model: a torch module that implements a `decode(z: Tensor) -> Distribution`.
679+
"""
680+
super().__init__()
681+
682+
self.model = model
683+
assert "decode" in dir(model)
684+
685+
def curve_energy(self, curve: BasicCurve) -> torch.Tensor:
686+
dt = (curve[:-1] - curve[1:]).pow(2).sum(dim=-1, keepdim=True) # (N-1)x1
687+
dist1 = self.model.decode(dt[:-1])
688+
dist2 = self.model.decode(dt[1:])
689+
690+
try:
691+
kl = kl_divergence(dist1, dist2)
692+
except Exception:
693+
# TODO: fix the exception. Is there a way of knowing if kl_div is
694+
# implemented for the distribution?
695+
kl = kl_by_sampling(dist1, dist2)
696+
697+
return (kl * dt).sum()

stochman/utilities/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
#!/usr/bin/env python
22
from .distance import squared_manifold_distance
3+
from .sampling import kl_by_sampling

stochman/utilities/sampling.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#!/usr/bin/env python3
2+
import torch
3+
from torch.distributions import Distribution
4+
5+
def kl_by_sampling(p: Distribution, q: Distribution, n_samples: int = 1000):
6+
"""
7+
Returns the mean KL by sampling.
8+
"""
9+
x = p.rsample((n_samples,)) # (n_samples)x(output of p)
10+
kl = (p.log_prob(x) - q.log_prob(x)).mean(dim=0).abs() # (output_of_p?)
11+
return kl

0 commit comments

Comments
 (0)