From 498ff8884f2d8876c1322b801db00dbd048c26ad Mon Sep 17 00:00:00 2001 From: danielward27 Date: Thu, 16 May 2024 13:28:55 +0100 Subject: [PATCH] Fix autoregressive masking --- flowjax/bijections/masked_autoregressive.py | 5 +++-- flowjax/distributions.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/flowjax/bijections/masked_autoregressive.py b/flowjax/bijections/masked_autoregressive.py index 4186f34f..65b63278 100644 --- a/flowjax/bijections/masked_autoregressive.py +++ b/flowjax/bijections/masked_autoregressive.py @@ -63,12 +63,13 @@ def __init__( if cond_dim is None: self.cond_shape = None in_ranks = jnp.arange(dim) + hidden_ranks = jnp.arange(nn_width) % (dim - 1) + # If dim=1, hidden ranks all zero -> all weights masked out in final layer else: self.cond_shape = (cond_dim,) # we give conditioning variables rank -1 (no masking of edges to output) in_ranks = jnp.hstack((jnp.arange(dim), -jnp.ones(cond_dim, int))) - - hidden_ranks = jnp.arange(nn_width) % dim + hidden_ranks = (jnp.arange(nn_width) % dim) - 1 out_ranks = jnp.repeat(jnp.arange(dim), num_params) self.masked_autoregressive_mlp = masked_autoregressive_mlp( diff --git a/flowjax/distributions.py b/flowjax/distributions.py index e525a994..aefa23a1 100644 --- a/flowjax/distributions.py +++ b/flowjax/distributions.py @@ -87,9 +87,9 @@ def log_prob(self, x: ArrayLike, condition: ArrayLike | None = None) -> Array: Array: Jax array of log probabilities. """ self = unwrap(self) - x = arraylike_to_array(x, err_name="x") + x = arraylike_to_array(x, err_name="x", dtype=float) if self.cond_shape is not None: - condition = arraylike_to_array(condition, err_name="condition") + condition = arraylike_to_array(condition, err_name="condition", dtype=float) lps = self._vectorize(self._log_prob)(x, condition) return jnp.where(jnp.isnan(lps), -jnp.inf, lps)