Skip to content

Commit

Permalink
Merge pull request #31 from danielward27/affine
Browse files Browse the repository at this point in the history
Affine
  • Loading branch information
danielward27 authored Sep 23, 2022
2 parents 7cacf73 + 78b5ec9 commit 31feca6
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 7 deletions.
99 changes: 99 additions & 0 deletions flowjax/bijections/affine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from flowjax.bijections.abc import Bijection
from flowjax.utils import Array
import jax.numpy as jnp
from flowjax.bijections.masked_autoregressive import rank_based_mask
from jax.scipy.linalg import solve_triangular


class Affine(Bijection):
loc: Array
log_scale: Array

def __init__(self, loc: Array, scale: Array):
"""Elementwise affine transformation. Condition is ignored.
Args:
loc (Array): Location parameter vector.
scale (Array): Scale parameter vector.
"""
self.loc = loc
self.log_scale = jnp.log(scale)
self.cond_dim = 0

def transform(self, x, condition = None):
return x * self.scale + self.loc

def transform_and_log_abs_det_jacobian(self, x, condition = None):
return x * self.scale + self.loc, self.log_scale.sum()

def inverse(self, y, condition = None):
return (y - self.loc) / self.scale

def inverse_and_log_abs_det_jacobian(self, y, condition = None):
return (y - self.loc) / self.scale, -self.log_scale.sum()

@property
def scale(self):
return jnp.exp(self.log_scale)


class TriangularAffine(Bijection):
loc: Array
dim: int
cond_dim: int
diag_mask: Array
tri_mask: Array
lower: bool
min_diag: float
_arr: Array
_log_diag: Array

def __init__(self, loc: Array, arr: Array, lower: bool = True, min_diag: float = 1e-6):
"""
Transformation of the form Ax + b, where A is a lower or upper triangular matrix. To
ensure invertiblility, diagonal entries should be positive (and greater than min_diag).
Args:
loc (Array): Translation.
arr (Array): Matrix.
lower (bool, optional): Whether the mask should select the lower or upper triangular matrix (other elements ignored). Defaults to True.
min_diag (float, optional): Minimum value on the diagonal, to ensure invertibility. Defaults to 1e-6.
"""

if (arr.ndim != 2) or (arr.shape[0] != arr.shape[1]):
ValueError("arr must be a square, 2-dimensional matrix.")
if jnp.any(jnp.diag(arr) < min_diag):
ValueError("arr diagonal entries must be greater than min_diag")

self.dim = arr.shape[0]
self.cond_dim = 0
self.diag_mask = jnp.eye(self.dim, dtype=jnp.int32)
tri_mask = jnp.tril(jnp.ones((self.dim, self.dim), dtype=jnp.int32), k=-1)
self.tri_mask = tri_mask if lower else tri_mask.T
self.min_diag = min_diag
self.lower = lower

# inexact arrays
self.loc = loc
self._arr = arr
self._log_diag = jnp.log(jnp.diag(arr) - min_diag)

@property
def arr(self):
"Get triangular array, (applies masking and min_diag constraint)."
diag = self.diag_mask*jnp.exp(self._log_diag) + self.min_diag
return self.tri_mask*self._arr + diag

def transform(self, x, condition = None):
return self.arr @ x + self.loc

def transform_and_log_abs_det_jacobian(self, x, condition = None):
a = self.arr
return a @ x + self.loc, jnp.diag(a).sum()

def inverse(self, y, condition = None):
return solve_triangular(self.arr, y - self.loc, lower=self.lower)

def inverse_and_log_abs_det_jacobian(self, y, condition = None):
a = self.arr
return solve_triangular(a, y - self.loc, lower=self.lower), -jnp.diag(a).sum()
2 changes: 1 addition & 1 deletion flowjax/bijections/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def transform(self, x, loc, scale):
return x * scale + loc

def transform_and_log_abs_det_jacobian(self, x, loc, scale):
return x * scale + loc, jnp.sum(jnp.log(scale))
return x * scale + loc, jnp.log(scale).sum()

def inverse(self, y, loc, scale):
return (y - loc) / scale
Expand Down
14 changes: 14 additions & 0 deletions tests/test_bijections/test_affine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest
from flowjax.bijections.affine import TriangularAffine
import jax.numpy as jnp

def test_TriangularAffine():
"Test mainly to check min_diag does not cause unexpected results on initialisation."
loc = jnp.array([1,2])
lower = jnp.array([[2, 0],[0.5, 3]])

bijection = TriangularAffine(loc, lower)
x = jnp.ones(2)
y = bijection.transform(x)
expected = lower @ x + loc
assert y == pytest.approx(expected)
14 changes: 8 additions & 6 deletions tests/test_bijections/test_invertibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from flowjax.bijections.transformers import AffineTransformer
from flowjax.bijections.utils import Flip, Permute
from flowjax.bijections.transformers import AffineTransformer, RationalQuadraticSplineTransformer
from flowjax.bijections.affine import Affine, TriangularAffine

transformers = {
"AffineTransformer": AffineTransformer(),
Expand All @@ -32,16 +33,17 @@ def test_transformer_invertibility(bijection):
assert log_det1 == pytest.approx(-log_det2, abs=1e-5)






dim = 5
cond_dim = 2
key = random.PRNGKey(0)
pos_def_triangles = jnp.full((dim,dim), 0.5) + jnp.diag(jnp.ones(dim))

bijections = {
"Flip": Flip(),
"Permute": Permute(jnp.flip(jnp.arange(dim))),
"Affine": Affine(jnp.ones(dim), jnp.full(dim, 2)),
"TriangularAffine (lower)": TriangularAffine(jnp.arange(dim), pos_def_triangles),
"TriangularAffine (upper)": TriangularAffine(jnp.arange(dim), pos_def_triangles, lower=False),
"Coupling (unconditional)": Coupling(
key,
AffineTransformer(),
Expand Down Expand Up @@ -71,7 +73,7 @@ def test_transformer_invertibility(bijection):
),
"MaskedAutoregressive_RationalQuadraticSpline (conditional)": MaskedAutoregressive(
key, RationalQuadraticSplineTransformer(5, 3), cond_dim=cond_dim, dim=dim, nn_width=10, nn_depth=2
)
),
}


Expand All @@ -89,5 +91,5 @@ def test_bijection_invertibility(bijection):
y, log_det1 = bijection.transform_and_log_abs_det_jacobian(x, cond)
x_reconstructed, log_det2 = bijection.inverse_and_log_abs_det_jacobian(y, cond)

assert x == pytest.approx(x_reconstructed, abs=1e-5) # Check invertibility
assert x == pytest.approx(x_reconstructed, abs=1e-5)
assert log_det1 == pytest.approx(-log_det2, abs=1e-5)

0 comments on commit 31feca6

Please sign in to comment.