Commit

Merge pull request #165 from danielward27/non_trainable
Add non_trainable
danielward27 authored Jun 21, 2024
2 parents c067413 + ec338bc commit 19c7355
Showing 4 changed files with 43 additions and 19 deletions.
23 changes: 11 additions & 12 deletions docs/faq.rst
@@ -4,22 +4,21 @@ FAQ
 Freezing parameters
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 Often it is useful to not train particular parameters. The easiest way to achieve this
-is to use the :class:`flowjax.wrappers.NonTrainable` wrapper class. For example, to
-avoid training the base distribution of a transformed distribution:
-
-.. testsetup::
-
-    from flowjax.distributions import Normal
-    flow = Normal()
+is to use :func:`flowjax.wrappers.non_trainable`. This will wrap the inexact array
+leaves with :class:`flowjax.wrappers.NonTrainable`, which will apply ``stop_gradient``
+when unwrapping the parameters. For commonly used distribution and bijection methods,
+unwrapping is applied automatically. For example

 .. doctest::

+    >>> from flowjax.distributions import Normal
+    >>> from flowjax.wrappers import non_trainable
+    >>> dist = Normal()
+    >>> dist = non_trainable(dist)

-    >>> import equinox as eqx
-    >>> from flowjax.wrappers import NonTrainable
-    >>> flow = eqx.tree_at(lambda flow: flow.base_dist, flow, replace_fn=NonTrainable)
+To mark part of a tree as frozen, use ``non_trainable`` with e.g.
+``equinox.tree_at`` or ``jax.tree_map``.

-If you wish to avoid training e.g. a specific type, it may be easier to use
-``jax.tree_map`` to apply the NonTrainable wrapper as required.

 Standardizing variables
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
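As a concrete illustration of the pattern described in the updated FAQ text (using ``non_trainable`` together with ``equinox.tree_at`` to freeze only part of a model), a minimal sketch follows. It uses flowjax's ``Transformed`` distribution and ``Affine`` bijection purely to build a small example, and mirrors the ``base_dist`` accessor from the removed doctest; adapt the names to your own model.

import equinox as eqx
from flowjax.bijections import Affine
from flowjax.distributions import Normal, Transformed
from flowjax.wrappers import non_trainable

# A small stand-in flow: a standard normal base distribution pushed through an Affine bijection.
flow = Transformed(Normal(), Affine())

# Freeze only the base distribution: its inexact array leaves are wrapped in NonTrainable,
# so gradients through them are stopped, while the Affine parameters stay trainable.
flow = eqx.tree_at(lambda flow: flow.base_dist, flow, replace_fn=non_trainable)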
25 changes: 24 additions & 1 deletion flowjax/wrappers.py
@@ -97,8 +97,11 @@ def unwrap(self) -> T:
 class NonTrainable(AbstractUnwrappable[T]):
     """Applies stop gradient to all arraylike leaves before unwrapping.

+    See also ``non_trainable``, which is probably a preferable way to achieve
+    similar behaviour, as it wraps the arraylike leaves directly rather than the tree.
+
     Useful to mark pytrees (arrays, submodules, etc) as frozen/non-trainable. We also
-    filter out these modules when partitioning parameters for training, or when
+    filter out NonTrainable nodes when partitioning parameters for training, or when
     parameterizing bijections in coupling/masked autoregressive flows (transformers).
     """

@@ -110,6 +113,26 @@ def unwrap(self) -> T:
         return eqx.combine(lax.stop_gradient(differentiable), static)


+def non_trainable(tree: PyTree):
+    """Freezes parameters by wrapping inexact array leaves with ``NonTrainable``.
+
+    Wrapping the arrays rather than the entire tree is often preferable, allowing
+    easier access to attributes of the wrapped object.
+
+    Args:
+        tree: The pytree.
+    """
+
+    def _map_fn(leaf):
+        return NonTrainable(leaf) if eqx.is_inexact_array(leaf) else leaf
+
+    return jax.tree_util.tree_map(
+        f=_map_fn,
+        tree=tree,
+        is_leaf=lambda x: isinstance(x, NonTrainable),
+    )
+
+
 def _apply_inverse_and_check_valid(bijection, arr):
     param_inv = bijection._vectorize.inverse(arr)
     return eqx.error_if(
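A rough usage sketch of the new wrapper follows, using only names that appear elsewhere in this diff (``unwrap``, ``NonTrainable``, ``non_trainable``, ``Normal``); the ``eqx.partition`` call is one illustration of the "filter out NonTrainable nodes when partitioning parameters" idea from the docstring, not flowjax's actual training code.

import equinox as eqx
from flowjax.distributions import Normal
from flowjax.wrappers import NonTrainable, non_trainable, unwrap

dist = non_trainable(Normal())  # wraps each inexact array leaf in a NonTrainable
dist = non_trainable(dist)      # the is_leaf check prevents double-wrapping on repeated application

# Unwrapping reconstructs the original tree, applying stop_gradient to the wrapped leaves;
# commonly used distribution and bijection methods do this automatically.
plain_dist = unwrap(dist)

# One way to keep frozen parameters away from an optimizer: treat NonTrainable nodes as
# leaves when partitioning, so they fall on the static (non-trainable) side.
params, static = eqx.partition(
    dist,
    eqx.is_inexact_array,
    is_leaf=lambda leaf: isinstance(leaf, NonTrainable),
)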
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -16,7 +16,7 @@ license = { file = "LICENSE" }
 name = "flowjax"
 readme = "README.md"
 requires-python = ">=3.10"
-version = "12.3.0"
+version = "12.4.0"

 [project.urls]
 repository = "https://github.com/danielward27/flowjax"
12 changes: 7 additions & 5 deletions tests/test_wrappers.py
Expand Up @@ -11,6 +11,7 @@
Lambda,
NonTrainable,
WeightNormalization,
non_trainable,
unwrap,
)

@@ -56,15 +57,16 @@ def test_Lambda():
     assert pytest.approx(unwrap(unwrappable)) == jnp.zeros((3, 2))


-def test_NonTrainable():
-    dist = Normal()
-    dist = eqx.tree_at(lambda dist: dist.bijection, dist, replace_fn=NonTrainable)
+def test_NonTrainable_and_non_trainable():
+    dist1 = eqx.tree_at(lambda dist: dist.bijection, Normal(), replace_fn=NonTrainable)
+    dist2 = non_trainable(Normal())

     def loss(dist, x):
         return dist.log_prob(x)

-    grad = eqx.filter_grad(loss)(dist, 1)
-    assert pytest.approx(0) == jax.flatten_util.ravel_pytree(grad)[0]
+    for dist in [dist1, dist2]:
+        grad = eqx.filter_grad(loss)(dist, 1)
+        assert pytest.approx(0) == jax.flatten_util.ravel_pytree(grad)[0]


 def test_WeightNormalization():
