Merge pull request #167 from danielward27/sample_contrastive
Sample contrastive
danielward27 authored Jul 24, 2024
2 parents bdae1ba + 20d1383 commit eeb3481
Showing 6 changed files with 63 additions and 20 deletions.
docs/examples/snpe.ipynb (10 changes: 5 additions & 5 deletions)

Large diffs are not rendered by default.

flowjax/train/data_fit.py (5 changes: 4 additions & 1 deletion)
@@ -97,21 +97,24 @@ def fit_to_data(
         # Train epoch
         batch_losses = []
         for batch in zip(*get_batches(train_data, batch_size), strict=True):
+            key, subkey = jr.split(key)
             params, opt_state, loss_i = step(
                 params,
                 static,
                 *batch,
                 optimizer=optimizer,
                 opt_state=opt_state,
                 loss_fn=loss_fn,
+                key=subkey,
             )
             batch_losses.append(loss_i)
         losses["train"].append(sum(batch_losses) / len(batch_losses))
 
         # Val epoch
         batch_losses = []
         for batch in zip(*get_batches(val_data, batch_size), strict=True):
-            loss_i = loss_fn(params, static, *batch)
+            key, subkey = jr.split(key)
+            loss_i = loss_fn(params, static, *batch, key=subkey)
             batch_losses.append(loss_i)
         losses["val"].append(sum(batch_losses) / len(batch_losses))

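Note on the change above: ``fit_to_data`` now threads its PRNG key through both loops, splitting off a fresh subkey for every train and validation batch and forwarding it to the loss via ``step(..., key=subkey)``. A minimal, self-contained sketch of that key-splitting idiom (``num_batches`` is a placeholder, not part of the diff):

import jax.random as jr

key = jr.PRNGKey(0)
num_batches = 10  # placeholder; in fit_to_data this is set by the data and batch_size
for _ in range(num_batches):
    key, subkey = jr.split(key)  # one fresh subkey per batch
    # subkey is then forwarded to the loss, e.g. loss_fn(params, static, *batch, key=subkey),
    # so stochastic losses such as ContrastiveLoss draw new contrastive indices on every call.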
flowjax/train/losses.py (40 changes: 28 additions & 12 deletions)
@@ -1,13 +1,18 @@
"""Common loss functions for training normalizing flows.
The loss functions are callables, with the first two arguments being the partitioned
distribution (see ``equinox.partition``).
In order to be compatible with ``fit_to_data``, the loss function arguments must match
``(params, static, x, condition, key)``, where ``params`` and ``static`` are the
partitioned model (see ``equinox.partition``).
For ``fit_to_variational_target``, the loss function signature must match
``(params, static, key)``.
"""

from collections.abc import Callable

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
from jax import vmap
from jax.lax import stop_gradient
from jax.scipy.special import logsumexp
@@ -30,8 +35,9 @@ def __call__(
         static: AbstractDistribution,
         x: Array,
         condition: Array | None = None,
+        key: PRNGKeyArray | None = None,
     ) -> Float[Array, ""]:
-        """Compute the loss."""
+        """Compute the loss. Key is ignored (for consistency of API)."""
         dist = unwrap(eqx.combine(params, static))
         return -dist.log_prob(x, condition).mean()

@@ -52,7 +58,7 @@ class ContrastiveLoss:
         prior: The prior distribution over x (the target
             variable).
         n_contrastive: The number of contrastive samples/atoms to use when
-            computing the loss.
+            computing the loss. Must be less than ``batch_size``.
 
     References:
         - https://arxiv.org/abs/1905.07488
@@ -69,30 +75,40 @@ def __call__(
         params: AbstractDistribution,
         static: AbstractDistribution,
         x: Float[Array, "..."],
-        condition: Array | None = None,
+        condition: Array | None,
+        key: PRNGKeyArray,
     ) -> Float[Array, ""]:
         """Compute the loss."""
+        if x.shape[0] <= self.n_contrastive:
+            raise ValueError(
+                f"Number of contrastive samples {self.n_contrastive} must be less than "
+                f"the size of x {x.shape}.",
+            )
+
         dist = unwrap(eqx.combine(params, static))
 
-        def single_x_loss(x_i, condition_i, idx):
+        def single_x_loss(x_i, condition_i, contrastive_idxs):
             positive_logit = dist.log_prob(x_i, condition_i) - self.prior.log_prob(x_i)
-            contrastive = jnp.delete(x, idx, assume_unique_indices=True, axis=0)[
-                : self.n_contrastive
-            ]
+            contrastive = x[contrastive_idxs]
             contrastive_logits = dist.log_prob(
                 contrastive, condition_i
             ) - self.prior.log_prob(contrastive)
             normalizer = logsumexp(jnp.append(contrastive_logits, positive_logit))
             return -(positive_logit - normalizer)
 
-        return eqx.filter_vmap(single_x_loss)(
-            x, condition, jnp.arange(x.shape[0], dtype=int)
-        ).mean()
+        contrastive_idxs = _get_contrastive_idxs(key, x.shape[0], self.n_contrastive)
+        return eqx.filter_vmap(single_x_loss)(x, condition, contrastive_idxs).mean()
+
+
+def _get_contrastive_idxs(key: PRNGKeyArray, batch_size: int, n_contrastive: int):
+
+    @eqx.filter_vmap
+    def _get_idxs(key, idx, batch_size, n_contrastive):
+        choices = jnp.delete(jnp.arange(batch_size), idx, assume_unique_indices=True)
+        return jr.choice(key, choices, (n_contrastive,), replace=False)
+
+    keys = jr.split(key, batch_size)
+    return _get_idxs(keys, jnp.arange(batch_size), batch_size, n_contrastive)
 
 
 class ElboLoss:
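As the updated module docstring states, any loss used with ``fit_to_data`` must now accept ``(params, static, x, condition, key)``, even when it ignores the key (as ``MaximumLikelihoodLoss`` does). A rough sketch of a conforming custom loss, assuming the model behaves like a flowjax distribution with a ``log_prob`` method; the ``unwrap`` step used by the built-in losses is omitted here:

import equinox as eqx


def my_custom_loss(params, static, x, condition=None, key=None):
    """Hypothetical negative log likelihood loss matching the required signature."""
    dist = eqx.combine(params, static)  # recombine the partitioned model
    return -dist.log_prob(x, condition).mean()

``ContrastiveLoss`` itself now consumes the key: ``_get_contrastive_idxs`` splits it into one subkey per batch element and samples ``n_contrastive`` indices without replacement, excluding the element's own index, replacing the previous deterministic choice of the first ``n_contrastive`` remaining rows.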
flowjax/train/train_utils.py (9 changes: 8 additions & 1 deletion)
@@ -19,6 +19,7 @@ def step(
     optimizer: optax.GradientTransformation,
     opt_state: PyTree,
     loss_fn: Callable[[PyTree, PyTree], Scalar],
+    **kwargs,
 ):
     """Carry out a training step.
@@ -30,11 +31,17 @@
         opt_state: Optimizer state.
         loss_fn: The loss function. This should take params and static as the first two
             arguments.
+        **kwargs: Key word arguments passed to the loss function.
 
     Returns:
         tuple: (params, opt_state, loss_val)
     """
-    loss_val, grads = eqx.filter_value_and_grad(loss_fn)(params, static, *args)
+    loss_val, grads = eqx.filter_value_and_grad(loss_fn)(
+        params,
+        static,
+        *args,
+        **kwargs,
+    )
     updates, opt_state = optimizer.update(grads, opt_state, params=params)
     params = eqx.apply_updates(params, updates)
     return params, opt_state, loss_val
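With ``**kwargs`` forwarding in ``step``, extra arguments such as the PRNG key pass straight through to the loss. A call-pattern sketch based on the call site in ``fit_to_data`` above (``x_batch`` and ``cond_batch`` stand in for one batch produced by ``get_batches``):

key, subkey = jr.split(key)
params, opt_state, loss_val = step(
    params,
    static,
    x_batch,             # positional *args consumed by the loss after params and static
    cond_batch,
    optimizer=optimizer,
    opt_state=opt_state,
    loss_fn=loss_fn,
    key=subkey,          # forwarded to loss_fn via **kwargs
)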
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -16,7 +16,7 @@ license = { file = "LICENSE" }
 name = "flowjax"
 readme = "README.md"
 requires-python = ">=3.10"
-version = "12.5.0"
+version = "13.0.0"
 
 [project.urls]
 repository = "https://github.com/danielward27/flowjax"
tests/test_train/test_losses.py (17 changes: 17 additions & 0 deletions)
@@ -0,0 +1,17 @@
+import jax.numpy as jnp
+import jax.random as jr
+
+from flowjax.train.losses import _get_contrastive_idxs
+
+
+def test_get_contrastive_idxs():
+    key = jr.PRNGKey(0)
+    batch_size = 5
+
+    for _ in range(5):
+        key, subkey = jr.split(key)
+        idxs = _get_contrastive_idxs(subkey, batch_size=batch_size, n_contrastive=4)
+        for i, row in enumerate(idxs):
+            assert i not in row
+
+    assert jnp.all(idxs < batch_size)
