Added forward and inverse maps to flow model, added tests for them
VincentStimper committed Jan 19, 2023
1 parent a988524 commit 79de120
Showing 2 changed files with 113 additions and 0 deletions.
60 changes: 60 additions & 0 deletions normflows/core.py
@@ -24,6 +24,66 @@ def __init__(self, q0, flows, p=None):
        self.flows = nn.ModuleList(flows)
        self.p = p

    def forward(self, z):
        """Transforms latent variable z to the flow variable x

        Args:
          z: Batch in the latent space

        Returns:
          Batch in the space of the target distribution
        """
        for flow in self.flows:
            z, _ = flow(z)
        return z

    def forward_and_log_det(self, z):
        """Transforms latent variable z to the flow variable x and
        computes log determinant of the Jacobian

        Args:
          z: Batch in the latent space

        Returns:
          Batch in the space of the target distribution,
          log determinant of the Jacobian
        """
        log_det = torch.zeros(len(z), device=z.device)
        for flow in self.flows:
            z, log_d = flow(z)
            log_det -= log_d
        return z, log_det

    def inverse(self, x):
        """Transforms flow variable x to the latent variable z

        Args:
          x: Batch in the space of the target distribution

        Returns:
          Batch in the latent space
        """
        for i in range(len(self.flows) - 1, -1, -1):
            x, _ = self.flows[i].inverse(x)
        return x

    def inverse_and_log_det(self, x):
        """Transforms flow variable x to the latent variable z and
        computes log determinant of the Jacobian

        Args:
          x: Batch in the space of the target distribution

        Returns:
          Batch in the latent space, log determinant of the
          Jacobian
        """
        log_det = torch.zeros(len(x), device=x.device)
        for i in range(len(self.flows) - 1, -1, -1):
            x, log_d = self.flows[i].inverse(x)
            log_det += log_d
        return x, log_det

    def forward_kld(self, x):
        """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)
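A rough usage sketch of the new mappings (not part of the commit; variable names and layer choices are illustrative): a small model built the same way as in the test below, with MaskedAffineFlow layers over a DiagGaussian base, round-tripped between latent and target space.

import torch

from normflows import NormalizingFlow
from normflows.flows import MaskedAffineFlow
from normflows.nets import MLP
from normflows.distributions.base import DiagGaussian

latent_size = 2
layers = []
for i in range(4):
    # Alternating binary masks, as in the test below
    b = torch.Tensor([j if i % 2 == j % 2 else 0 for j in range(latent_size)])
    s = MLP([latent_size, 2 * latent_size, latent_size], init_zeros=True)
    t = MLP([latent_size, 2 * latent_size, latent_size], init_zeros=True)
    layers.append(MaskedAffineFlow(b, t, s))
model = NormalizingFlow(DiagGaussian(latent_size), layers)

z = torch.randn(8, latent_size)
x = model.forward(z)                        # latent -> target space
z_ = model.inverse(x)                       # target -> latent space, recovers z
x, log_det = model.forward_and_log_det(z)   # same map, with accumulated log determinant
z_, log_det_ = model.inverse_and_log_det(x)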
53 changes: 53 additions & 0 deletions normflows/core_test.py
@@ -0,0 +1,53 @@
import unittest
import torch

from torch.testing import assert_close
from normflows import NormalizingFlow
from normflows.flows import MaskedAffineFlow
from normflows.nets import MLP
from normflows.distributions.base import DiagGaussian
from normflows.distributions.target import CircularGaussianMixture


class CoreTest(unittest.TestCase):
    def test_mask_affine(self):
        batch_size = 5
        latent_size = 2
        for n_layers in [2, 5]:
            with self.subTest(n_layers=n_layers):
                # Construct flow model
                layers = []
                for i in range(n_layers):
                    b = torch.Tensor([j if i % 2 == j % 2 else 0 for j in range(latent_size)])
                    s = MLP([latent_size, 2 * latent_size, latent_size], init_zeros=True)
                    t = MLP([latent_size, 2 * latent_size, latent_size], init_zeros=True)
                    layers.append(MaskedAffineFlow(b, t, s))
                base = DiagGaussian(latent_size)
                target = CircularGaussianMixture()
                model = NormalizingFlow(base, layers, target)
                inputs = torch.randn((batch_size, latent_size))
                # Test log prob and sampling
                log_q = model.log_prob(inputs)
                assert log_q.shape == (batch_size,)
                s, log_q = model.sample(batch_size)
                assert log_q.shape == (batch_size,)
                assert s.shape == (batch_size, latent_size)
                # Test losses
                loss = model.forward_kld(inputs)
                assert loss.dim() == 0
                loss = model.reverse_kld(batch_size)
                assert loss.dim() == 0
                loss = model.reverse_alpha_div(batch_size)
                assert loss.dim() == 0
                # Test forward and inverse
                outputs = model.forward(inputs)
                inputs_ = model.inverse(outputs)
                assert_close(inputs_, inputs)
                outputs, log_det = model.forward_and_log_det(inputs)
                inputs_, log_det_ = model.inverse_and_log_det(outputs)
                assert_close(inputs_, inputs)
                assert_close(log_det, -log_det_)


if __name__ == "__main__":
    unittest.main()
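A side note on the construction used in the test (my reading, not part of the commit): because the scale and translation MLPs are created with init_zeros=True, each MaskedAffineFlow starts out as the identity map, so the round-trip and log-determinant checks run against a well-defined initial state. A minimal sketch of that property:

import torch

from normflows.flows import MaskedAffineFlow
from normflows.nets import MLP

latent_size = 2
b = torch.Tensor([1, 0])  # mask half the dimensions
s = MLP([latent_size, 2 * latent_size, latent_size], init_zeros=True)
t = MLP([latent_size, 2 * latent_size, latent_size], init_zeros=True)
flow = MaskedAffineFlow(b, t, s)

z = torch.randn(4, latent_size)
x, log_det = flow(z)
# With the zero-initialized output layers, scale and shift are zero,
# so x equals z and log_det is zero for every sample in the batch.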
