Skip to content

Commit

Permalink
Fix autoregressive masking
Browse files Browse the repository at this point in the history
  • Loading branch information
danielward27 committed May 16, 2024
1 parent 039e40b commit 498ff88
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
5 changes: 3 additions & 2 deletions flowjax/bijections/masked_autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ def __init__(
if cond_dim is None:
self.cond_shape = None
in_ranks = jnp.arange(dim)
hidden_ranks = jnp.arange(nn_width) % (dim - 1)
# If dim=1, hidden ranks all zero -> all weights masked out in final layer
else:
self.cond_shape = (cond_dim,)
# we give conditioning variables rank -1 (no masking of edges to output)
in_ranks = jnp.hstack((jnp.arange(dim), -jnp.ones(cond_dim, int)))

hidden_ranks = jnp.arange(nn_width) % dim
hidden_ranks = (jnp.arange(nn_width) % dim) - 1
out_ranks = jnp.repeat(jnp.arange(dim), num_params)

self.masked_autoregressive_mlp = masked_autoregressive_mlp(
Expand Down
4 changes: 2 additions & 2 deletions flowjax/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def log_prob(self, x: ArrayLike, condition: ArrayLike | None = None) -> Array:
Array: Jax array of log probabilities.
"""
self = unwrap(self)
x = arraylike_to_array(x, err_name="x")
x = arraylike_to_array(x, err_name="x", dtype=float)
if self.cond_shape is not None:
condition = arraylike_to_array(condition, err_name="condition")
condition = arraylike_to_array(condition, err_name="condition", dtype=float)
lps = self._vectorize(self._log_prob)(x, condition)
return jnp.where(jnp.isnan(lps), -jnp.inf, lps)

Expand Down

0 comments on commit 498ff88

Please sign in to comment.