Merge pull request #52 from ardizzone/ibinn_improvements
Improvements to FrEIA that came about from the IB-INN paper
wapu authored Feb 2, 2021
2 parents 979f6b1 + cf4e885 commit 033f6fd
Showing 7 changed files with 307 additions and 26 deletions.
1 change: 0 additions & 1 deletion FrEIA/__init__.py
@@ -3,6 +3,5 @@
structure of operations.'''
from . import framework
from . import modules
from . import dummy_modules

__all__ = ["framework", "modules"]
16 changes: 16 additions & 0 deletions FrEIA/framework/__init__.py
@@ -0,0 +1,16 @@
'''The framework module contains the logic used in building the graph and
inferring the order that the nodes have to be executed in forward and backward
direction.'''

from .reversible_graph_net import *
from .reversible_sequential_net import *

__all__ = [
'ReversibleSequential',
'ReversibleGraphNet',
'Node',
'InputNode',
'ConditionNode',
'OutputNode'
]
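For reference, a minimal sketch (not part of the diff) of how the re-exported names are imported after this change; the import path stays the same as with the old single-file framework module:

    from FrEIA.framework import (ReversibleSequential, ReversibleGraphNet,
                                 Node, InputNode, ConditionNode, OutputNode)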

File renamed without changes.
24 changes: 0 additions & 24 deletions FrEIA/framework.py → FrEIA/framework/reversible_graph_net.py
@@ -1,7 +1,3 @@
'''The framework module contains the logic used in building the graph and
inferring the order that the nodes have to be executed in forward and backward
direction.'''

import sys
import warnings
import numpy as np
@@ -499,23 +495,3 @@ def get_module_by_name(self, name):
return node.module
except:
return None



# Testing example
if __name__ == '__main__':
inp = InputNode(4, 64, 64, name='input')
t1 = Node([(inp, 0)], dummys.dummy_mux, {}, name='t1')
s1 = Node([(t1, 0)], dummys.dummy_2split, {}, name='s1')

t2 = Node([(s1, 0)], dummys.dummy_module, {}, name='t2')
s2 = Node([(s1, 1)], dummys.dummy_2split, {}, name='s2')
t3 = Node([(s2, 0)], dummys.dummy_module, {}, name='t3')

m1 = Node([(t3, 0), (s2, 1)], dummys.dummy_2merge, {}, name='m1')
m2 = Node([(t2, 0), (m1, 0)], dummys.dummy_2merge, {}, name='m2')
outp = OutputNode([(m2, 0)], name='output')

all_nodes = [inp, outp, t1, s1, t2, s2, t3, m1, m2]

net = ReversibleGraphNet(all_nodes, 0, 1)
76 changes: 76 additions & 0 deletions FrEIA/framework/reversible_sequential_net.py
@@ -0,0 +1,76 @@
import torch.nn as nn
import torch

class ReversibleSequential(nn.Module):
'''Simpler than FrEIA.framework.ReversibleGraphNet:
only supports a sequential series of modules (no splitting, merging, or branching off).
Has an append() method to add new blocks in a simpler way than the computation-graph
based approach of ReversibleGraphNet. For example:
inn = ReversibleSequential(channels, dims_H, dims_W)
for i in range(n_blocks):
inn.append(FrEIA.modules.AllInOneBlock, clamp=2.0, permute_soft=True)
inn.append(FrEIA.modules.HaarDownsampling)
# and so on
'''

def __init__(self, *dims):
super().__init__()

self.shapes = [tuple(dims)]
self.conditions = []
self.module_list = nn.ModuleList()

def append(self, module_class, cond=None, cond_shape=None, **kwargs):
'''Append a reversible block from FrEIA.modules to the network.
module_class: Class from FrEIA.modules.
cond (int): index of which condition to use (conditions will be passed as list to forward()).
Conditioning nodes are not needed for ReversibleSequential.
cond_shape (tuple[int]): the shape of the condition tensor.
**kwargs: Further keyword arguments that are passed to the constructor of module_class (see example).
'''

dims_in = [self.shapes[-1]]
self.conditions.append(cond)

if cond is not None:
kwargs['dims_c'] = [cond_shape]

module = module_class(dims_in, **kwargs)
self.module_list.append(module)
output_dims = module.output_dims(dims_in)
assert len(output_dims) == 1, "Module has more than one output"
self.shapes.append(output_dims[0])


def forward(self, x, c=None, rev=False):
'''
x (Tensor): input tensor (in contrast to ReversibleGraphNet, a list of tensors is not
supported, as ReversibleSequential only has one input).
c (list[Tensor]): list of conditions.
rev: whether to compute the network forward or reversed.
Returns:
z (Tensor): network output.
jac (Tensor): log-Jacobian determinant.
There is no separate log_jacobian() method; it is computed automatically during forward().
'''

iterator = range(len(self.module_list))
jac = 0

if rev:
iterator = reversed(iterator)

for i in iterator:
if self.conditions[i] is None:
x, j = (self.module_list[i]([x], rev=rev)[0],
self.module_list[i].jacobian(x, rev=rev))
else:
x, j = (self.module_list[i]([x], c=[c[self.conditions[i]]], rev=rev)[0],
self.module_list[i].jacobian(x, c=[c[self.conditions[i]]], rev=rev))
jac = j + jac

return x, jac
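Usage note (illustrative, not part of the commit): a minimal sketch of driving ReversibleSequential end to end, assuming FrEIA.modules.AllInOneBlock from this commit; the subnet constructor subnet_fc and the tensor sizes are invented for the example:

    import torch
    import torch.nn as nn
    import FrEIA.framework as Ff
    import FrEIA.modules as Fm

    def subnet_fc(c_in, c_out):
        # illustrative two-layer subnet passed to the coupling block
        return nn.Sequential(nn.Linear(c_in, 128), nn.ReLU(), nn.Linear(128, c_out))

    inn = Ff.ReversibleSequential(8)          # one 8-dimensional input
    for _ in range(4):
        inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc, permute_soft=True)

    x = torch.randn(16, 8)
    z, log_jac = inn(x)                       # forward pass; log-Jacobian accumulated over blocks
    x_rec, _ = inn(z, rev=True)               # inverse pass recovers the input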
4 changes: 3 additions & 1 deletion FrEIA/modules/__init__.py
@@ -5,6 +5,7 @@
Coupling blocks:
* AllInOneBlock
* NICECouplingBlock
* RNVPCouplingBlock
* GLOWCouplingBlock
@@ -33,7 +34,6 @@
Graph topology:
* SplitChannel
* ConcatChannel
* Split1D
@@ -50,6 +50,7 @@
'''

from .all_in_one_block import *
from .fixed_transforms import *
from .reshapes import *
from .coupling_layers import *
@@ -61,6 +62,7 @@
from .gaussian_mixture import *

__all__ = [
'AllInOneBlock',
'glow_coupling_layer',
'rev_layer',
'rev_multiplicative_layer',
212 changes: 212 additions & 0 deletions FrEIA/modules/all_in_one_block.py
@@ -0,0 +1,212 @@
import pdb
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import special_ortho_group

class AllInOneBlock(nn.Module):
'''Combines an affine coupling layer, a permutation, and a global affine transformation
('ActNorm') in one block.'''

def __init__(self, dims_in, dims_c=[],
subnet_constructor=None,
affine_clamping=2.,
gin_block=False,
global_affine_init=1.,
global_affine_type='SOFTPLUS',
permute_soft=False,
learned_householder_permutation=0,
reverse_permutation=False):
'''
subnet_constructor: class or callable f, called as
f(channels_in, channels_out) and should return a torch.nn.Module
affine_clamping: clamp the output of the multiplicative coefficients
(before exponentiation) to +/- affine_clamping.
gin_block: Turn the block into a GIN block from Sorrenson et al., 2019.
global_affine_init: Initial value for the global affine scaling beta.
global_affine_type: 'SIGMOID', 'SOFTPLUS', or 'EXP'. Defines the activation
to be used on the beta for the global affine scaling.
permute_soft: bool, whether to sample the permutation matrices from SO(N),
or to use hard permutations instead. Note, permute_soft=True is very slow
when working with >512 dimensions.
learned_householder_permutation: Int, if >0, use that many learned Householder
reflections. Slow for a large number of reflections. Dubious whether it actually helps.
reverse_permutation: Reverse the permutation before the block, as introduced by
Putzky et al., 2019.
'''

super().__init__()

channels = dims_in[0][0]
self.Ddim = len(dims_in[0]) - 1
self.sum_dims = tuple(range(1, 2 + self.Ddim))

if len(dims_c) == 0:
self.conditional = False
self.condition_channels = 0
elif len(dims_c) == 1:
self.conditional = True
self.condition_channels = dims_c[0][0]
assert tuple(dims_c[0][1:]) == tuple(dims_in[0][1:]), \
F"Dimensions of input and condition don't agree: {dims_c} vs {dims_in}."
else:
raise ValueError('Only supports one condition (concatenate externally)')

split_len1 = channels - channels // 2
split_len2 = channels // 2
self.splits = [split_len1, split_len2]


try:
self.permute_function = {0 : F.linear,
1 : F.conv1d,
2 : F.conv2d,
3 : F.conv3d}[self.Ddim]
except KeyError:
raise ValueError(f"Data has {1 + self.Ddim} dimensions. Must be 1-4.")

self.in_channels = channels
self.clamp = affine_clamping
self.GIN = gin_block
self.welling_perm = reverse_permutation
self.householder = learned_householder_permutation

if permute_soft and channels > 512:
warnings.warn(("Soft permutation will take a very long time to initialize "
f"with {channels} feature channels. Consider using hard permutation instead."))

if global_affine_type == 'SIGMOID':
global_scale = np.log(global_affine_init)
self.global_scale_activation = (lambda a: 10 * torch.sigmoid(a - 2.))
elif global_affine_type == 'SOFTPLUS':
global_scale = 10. * global_affine_init
self.softplus = nn.Softplus(beta=0.5)
self.global_scale_activation = (lambda a: 0.1 * self.softplus(a))
elif global_affine_type == 'EXP':
global_scale = np.log(global_affine_init)
self.global_scale_activation = (lambda a: torch.exp(a))
else:
raise ValueError("global_affine_type must be 'SIGMOID', 'SOFTPLUS' or 'EXP'")

self.global_scale = nn.Parameter(torch.ones(1, self.in_channels, *([1] * self.Ddim)) * float(global_scale))
self.global_offset = nn.Parameter(torch.zeros(1, self.in_channels, *([1] * self.Ddim)))

if permute_soft:
w = special_ortho_group.rvs(channels)
else:
w = np.zeros((channels,channels))
for i,j in enumerate(np.random.permutation(channels)):
w[i,j] = 1.

if self.householder:
self.vk_householder = nn.Parameter(0.2 * torch.randn(self.householder, channels), requires_grad=True)
self.w = None
self.w_inv = None
self.w_0 = nn.Parameter(torch.FloatTensor(w), requires_grad=False)
else:
self.w = nn.Parameter(torch.FloatTensor(w).view(channels, channels, *([1] * self.Ddim)),
requires_grad=False)
self.w_inv = nn.Parameter(torch.FloatTensor(w.T).view(channels, channels, *([1] * self.Ddim)),
requires_grad=False)

self.s = subnet_constructor(self.splits[0] + self.condition_channels, 2 * self.splits[1])
self.last_jac = None

def construct_householder_permutation(self):
w = self.w_0
for vk in self.vk_householder:
w = torch.mm(w, torch.eye(self.in_channels).to(w.device) - 2 * torch.ger(vk, vk) / torch.dot(vk, vk))

for i in range(self.Ddim):
w = w.unsqueeze(-1)
return w

def log_e(self, s):
s = self.clamp * torch.tanh(0.1 * s)
if self.GIN:
s -= torch.mean(s, dim=self.sum_dims, keepdim=True)
return s

def permute(self, x, rev=False):
if self.GIN:
scale = 1.
else:
scale = self.global_scale_activation(self.global_scale)
if rev:
return (self.permute_function(x, self.w_inv) - self.global_offset) / scale
else:
return self.permute_function(x * scale + self.global_offset, self.w)

def pre_permute(self, x, rev=False):
if rev:
return self.permute_function(x, self.w)
else:
return self.permute_function(x, self.w_inv)

def affine(self, x, a, rev=False):
ch = x.shape[1]
sub_jac = self.log_e(a[:,:ch])
if not rev:
return (x * torch.exp(sub_jac) + 0.1 * a[:,ch:],
torch.sum(sub_jac, dim=self.sum_dims))
else:
return ((x - 0.1 * a[:,ch:]) * torch.exp(-sub_jac),
-torch.sum(sub_jac, dim=self.sum_dims))

def forward(self, x, c=[], rev=False):
if self.householder:
self.w = self.construct_householder_permutation()
if rev or self.welling_perm:
self.w_inv = self.w.transpose(0,1).contiguous()

if rev:
x = [self.permute(x[0], rev=True)]
elif self.welling_perm:
x = [self.pre_permute(x[0], rev=False)]

x1, x2 = torch.split(x[0], self.splits, dim=1)

if self.conditional:
x1c = torch.cat([x1, *c], 1)
else:
x1c = x1

if not rev:
a1 = self.s(x1c)
x2, j2 = self.affine(x2, a1)
else:
# names of x and y are swapped!
a1 = self.s(x1c)
x2, j2 = self.affine(x2, a1, rev=True)

self.last_jac = j2
x_out = torch.cat((x1, x2), 1)

n_pixels = 1
for d in self.sum_dims[1:]:
n_pixels *= x_out.shape[d]

self.last_jac += ((-1)**rev * n_pixels) * (1 - int(self.GIN)) * (torch.log(self.global_scale_activation(self.global_scale) + 1e-12).sum())

if not rev:
x_out = self.permute(x_out, rev=False)
elif self.welling_perm:
x_out = self.pre_permute(x_out, rev=True)

return [x_out]

def jacobian(self, x, c=[], rev=False):
return self.last_jac

def output_dims(self, input_dims):
return input_dims
