Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make uniform non-trainable #199

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Make uniform non-trainable
  • Loading branch information
danielward27 committed Dec 18, 2024
commit 27d1d2c24b2b75cb1a424259a7dcd98655426c23
9 changes: 5 additions & 4 deletions flowjax/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from jax.scipy.special import logsumexp
from jax.tree_util import tree_map
from jaxtyping import Array, ArrayLike, PRNGKeyArray, Shaped
from paramax import AbstractUnwrappable, Parameterize, unwrap
from paramax import AbstractUnwrappable, Parameterize, non_trainable, unwrap
from paramax.utils import inv_softplus

from flowjax.bijections import (
Expand Down Expand Up @@ -478,17 +478,18 @@ def __init__(self, minval: ArrayLike, maxval: ArrayLike):
(minval, maxval), maxval <= minval, "minval must be less than the maxval."
)
self.base_dist = _StandardUniform(shape)
self.bijection = Affine(loc=minval, scale=maxval - minval)
self.bijection = non_trainable(Affine(loc=minval, scale=maxval - minval))

@property
def minval(self):
"""Minimum value of the uniform distribution."""
return self.bijection.loc
return unwrap(self.bijection.loc)

@property
def maxval(self):
"""Maximum value of the uniform distribution."""
return self.bijection.loc + unwrap(self.bijection.scale)
unwrapped = unwrap(self)
return unwrapped.loc + unwrapped.scale


class _StandardGumbel(AbstractDistribution):
Expand Down
Loading