Skip to content

Commit

Permalink
Test for coupling layers added
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentStimper committed Nov 6, 2022
1 parent 7cda812 commit ce1b8c5
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 2 deletions.
46 changes: 46 additions & 0 deletions normflows/flows/affine/coupling_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import unittest
import torch

from torch.testing import assert_close
from normflows.flows import MaskedAffineFlow, CCAffineConst
from normflows.nets import MLP
from normflows.flows.flow_test import FlowTest


class CouplingTest(FlowTest):
def test_mask_affine(self):
batch_size = 5
for latent_size in [2, 7]:
with self.subTest(latent_size=latent_size):
b = torch.Tensor([1 if i % 2 == 0 else 0 for i in range(latent_size)])
s = MLP([latent_size, 2 * latent_size, latent_size], init_zeros=True)
t = MLP([latent_size, 2 * latent_size, latent_size], init_zeros=True)
flow = MaskedAffineFlow(b, t, s)
inputs = torch.randn((batch_size, latent_size))
self.checkForwardInverse(flow, inputs)

def test_cc_affine(self):
batch_size = 5
for shape in [(5,), (2, 3, 4)]:
for num_classes in [2, 5]:
with self.subTest(shape=shape, num_classes=num_classes):
flow = CCAffineConst(shape, num_classes)
x = torch.randn((batch_size,) + shape)
y = torch.rand((batch_size,) + (num_classes,))
x_, log_det = flow(x, y)
x__, log_det_ = flow(x_, y)

assert x_.dtype == x.dtype
assert x__.dtype == x.dtype

assert x_.shape == x.shape
assert x__.shape == x.shape

assert_close(x__, x)
id_ld = log_det + log_det_
assert_close(id_ld, torch.zeros_like(id_ld))



if __name__ == "__main__":
unittest.main()
1 change: 0 additions & 1 deletion normflows/flows/affine/glow_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import unittest

import torch

from normflows.flows import GlowBlock
Expand Down
2 changes: 1 addition & 1 deletion normflows/flows/flow_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest

import torch

from torch.testing import assert_close


Expand Down

0 comments on commit ce1b8c5

Please sign in to comment.