From 79de120ce016a539975c243a1bc6b5357003f1e8 Mon Sep 17 00:00:00 2001
From: Vincent Stimper
Date: Thu, 19 Jan 2023 19:04:19 +0100
Subject: [PATCH] Added forward and inverse map to flow model, added tests for
 it

---
 normflows/core.py      | 60 ++++++++++++++++++++++++++++++++++++++++++
 normflows/core_test.py | 53 +++++++++++++++++++++++++++++++++++++
 2 files changed, 113 insertions(+)
 create mode 100644 normflows/core_test.py

diff --git a/normflows/core.py b/normflows/core.py
index 8c84fb6..d19bd72 100644
--- a/normflows/core.py
+++ b/normflows/core.py
@@ -24,6 +24,66 @@ def __init__(self, q0, flows, p=None):
         self.flows = nn.ModuleList(flows)
         self.p = p
 
+    def forward(self, z):
+        """Transforms latent variable z to the flow variable x
+
+        Args:
+          z: Batch in the latent space
+
+        Returns:
+          Batch in the space of the target distribution
+        """
+        for flow in self.flows:
+            z, _ = flow(z)
+        return z
+
+    def forward_and_log_det(self, z):
+        """Transforms latent variable z to the flow variable x and
+        computes the log determinant of the Jacobian
+
+        Args:
+          z: Batch in the latent space
+
+        Returns:
+          Batch in the space of the target distribution,
+          log determinant of the Jacobian
+        """
+        log_det = torch.zeros(len(z), device=z.device)
+        for flow in self.flows:
+            z, log_d = flow(z)
+            log_det += log_d
+        return z, log_det
+
+    def inverse(self, x):
+        """Transforms flow variable x to the latent variable z
+
+        Args:
+          x: Batch in the space of the target distribution
+
+        Returns:
+          Batch in the latent space
+        """
+        for i in range(len(self.flows) - 1, -1, -1):
+            x, _ = self.flows[i].inverse(x)
+        return x
+
+    def inverse_and_log_det(self, x):
+        """Transforms flow variable x to the latent variable z and
+        computes the log determinant of the Jacobian
+
+        Args:
+          x: Batch in the space of the target distribution
+
+        Returns:
+          Batch in the latent space, log determinant of the
+          Jacobian
+        """
+        log_det = torch.zeros(len(x), device=x.device)
+        for i in range(len(self.flows) - 1, -1, -1):
+            x, log_d = self.flows[i].inverse(x)
+            log_det += log_d
+        return x, log_det
+
     def forward_kld(self, x):
         """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)
 
diff --git a/normflows/core_test.py b/normflows/core_test.py
new file mode 100644
index 0000000..3246f64
--- /dev/null
+++ b/normflows/core_test.py
@@ -0,0 +1,53 @@
+import unittest
+import torch
+
+from torch.testing import assert_close
+from normflows import NormalizingFlow
+from normflows.flows import MaskedAffineFlow
+from normflows.nets import MLP
+from normflows.distributions.base import DiagGaussian
+from normflows.distributions.target import CircularGaussianMixture
+
+
+class CoreTest(unittest.TestCase):
+    def test_mask_affine(self):
+        batch_size = 5
+        latent_size = 2
+        for n_layers in [2, 5]:
+            with self.subTest(n_layers=n_layers):
+                # Construct flow model
+                layers = []
+                for i in range(n_layers):
+                    b = torch.Tensor([j if i % 2 == j % 2 else 0 for j in range(latent_size)])
+                    s = MLP([latent_size, 2 * latent_size, latent_size], init_zeros=True)
+                    t = MLP([latent_size, 2 * latent_size, latent_size], init_zeros=True)
+                    layers.append(MaskedAffineFlow(b, t, s))
+                base = DiagGaussian(latent_size)
+                target = CircularGaussianMixture()
+                model = NormalizingFlow(base, layers, target)
+                inputs = torch.randn((batch_size, latent_size))
+                # Test log prob and sampling
+                log_q = model.log_prob(inputs)
+                assert log_q.shape == (batch_size,)
+                s, log_q = model.sample(batch_size)
+                assert log_q.shape == (batch_size,)
+                assert s.shape == (batch_size, latent_size)
+                # Test losses
+                loss = model.forward_kld(inputs)
+                assert loss.dim() == 0
+                loss = model.reverse_kld(batch_size)
+                assert loss.dim() == 0
+                loss = model.reverse_alpha_div(batch_size)
+                assert loss.dim() == 0
+                # Test forward and inverse
+                outputs = model.forward(inputs)
+                inputs_ = model.inverse(outputs)
+                assert_close(inputs_, inputs)
+                outputs, log_det = model.forward_and_log_det(inputs)
+                inputs_, log_det_ = model.inverse_and_log_det(outputs)
+                assert_close(inputs_, inputs)
+                assert_close(log_det, -log_det_)
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
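
For reviewers, a minimal usage sketch of the new methods (not part of the patch; model construction mirrors core_test.py, and the mask pattern, layer count, and variable names here are arbitrary illustrative choices):

import torch
import normflows as nf

# Build a small RealNVP-style flow, as in core_test.py
latent_size = 2
flows = []
for i in range(4):
    # Alternating binary mask for the coupling layers
    b = torch.Tensor([1 if i % 2 == j % 2 else 0 for j in range(latent_size)])
    s = nf.nets.MLP([latent_size, 2 * latent_size, latent_size], init_zeros=True)
    t = nf.nets.MLP([latent_size, 2 * latent_size, latent_size], init_zeros=True)
    flows.append(nf.flows.MaskedAffineFlow(b, t, s))
base = nf.distributions.base.DiagGaussian(latent_size)
model = nf.NormalizingFlow(base, flows)

z = torch.randn(16, latent_size)
x = model.forward(z)                        # latent -> target space
x, log_det = model.forward_and_log_det(z)   # also returns log|det dx/dz|
z_ = model.inverse(x)                       # target -> latent space
z_, log_det_ = model.inverse_and_log_det(x)

# The round trip recovers z, and the two log determinants cancel
assert torch.allclose(z_, z, atol=1e-5)
assert torch.allclose(log_det, -log_det_, atol=1e-5)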