
gerlero/parametrix



flax.nnx.Param-like computed parameters for bare JAX (and Equinox).


Installation

pip install parametrix

Example

The following example uses Param as a base class for a parameter class that enforces positivity. It stores the logarithm of the value and exposes its exponential:

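import jax.numpy as jnp
from parametrix import Param

class PositiveOnlyParam(Param):
    def __init__(self, value):
        # Store the unconstrained log of the value as the backing raw_value
        super().__init__(jnp.log(value))

    @property
    def value(self):
        # The computed value exp(raw_value) is always strictly positive
        return jnp.exp(self.raw_value)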
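The log/exp trick in PositiveOnlyParam matters mostly during optimization. The following plain-JAX sketch does not use parametrix itself, and the loss, target and step size are hypothetical. It compares one gradient step taken directly on the value with one step taken on raw = log(value). The direct step overshoots below zero, while the reparameterized value can only shrink toward zero:

```python
import jax
import jax.numpy as jnp

def loss_on_value(v, target=0.1):
    # Hypothetical objective: pull the parameter toward a small positive target
    return (v - target) ** 2

lr = 0.6  # deliberately large step size
v0 = jnp.array(5.0)

# Naive: update the value directly; one step overshoots below zero
v_naive = v0 - lr * jax.grad(loss_on_value)(v0)

# Reparameterized, as in PositiveOnlyParam: update raw = log(value)
# and read the value back through exp(), which is always positive
raw = jnp.log(v0)
raw = raw - lr * jax.grad(lambda r: loss_on_value(jnp.exp(r)))(raw)
v_reparam = jnp.exp(raw)

print(v_naive, v_reparam)  # negative vs. tiny but positive
```

The reparameterized step is still too large here. The point is only that no choice of raw value can produce a non-positive parameter.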

The backing value of a Param (its raw_value) is always stored as a jax.Array. As a result, libraries such as Equinox automatically pick it up as a learnable parameter.
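Frameworks find that array through the pytree mechanism. The sketch below uses a hypothetical MiniParam class as a stand-in, not parametrix's actual implementation. The class is registered as a JAX pytree whose only leaf is raw_value. Bare jax.grad then returns a gradient with respect to the backing array, even though the model only ever reads the computed value:

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class MiniParam:
    """Hypothetical stand-in for a Param: a single jax.Array leaf, raw_value."""

    def __init__(self, raw_value):
        self.raw_value = jnp.asarray(raw_value)

    @property
    def value(self):
        # Computed (positive) value, as in PositiveOnlyParam
        return jnp.exp(self.raw_value)

    def tree_flatten(self):
        # The backing array is the only child, so JAX transformations see it
        return (self.raw_value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__: JAX may pass non-array placeholders as children
        obj = object.__new__(cls)
        obj.raw_value = children[0]
        return obj

model = {"scale": MiniParam(jnp.log(2.0)), "bias": jnp.array(0.5)}

def loss(model, x):
    # The model uses the computed value; the gradient flows to raw_value
    return jnp.sum((model["scale"].value * x + model["bias"]) ** 2)

leaves = jax.tree_util.tree_leaves(model)  # [bias, scale.raw_value]
grads = jax.grad(loss)(model, jnp.array([1.0, 2.0]))
print(grads["scale"].raw_value, grads["bias"])
```

Equinox modules are pytrees built on the same machinery. That is why an array leaf like this one is treated as trainable there as well.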

Param objects also behave like numeric types, so they can be used inside models and other functions without any changes to the existing code.
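One way an object can "behave like a numeric type" is by forwarding Python's arithmetic operators to its computed value. The sketch below illustrates that idea with a hypothetical NumericParam class. It does not reproduce parametrix's actual implementation, and a fuller version would cover many more operators:

```python
import jax.numpy as jnp

class NumericParam:
    """Hypothetical sketch: forwards arithmetic to the computed value."""

    def __init__(self, raw_value):
        self.raw_value = jnp.asarray(raw_value)

    @property
    def value(self):
        return jnp.exp(self.raw_value)

    # Binary operators delegate to the computed value, so the object can be
    # passed to code that expects a number or an array.
    def __add__(self, other):
        return self.value + other

    def __radd__(self, other):
        return other + self.value

    def __mul__(self, other):
        return self.value * other

    def __rmul__(self, other):
        return other * self.value

    def __float__(self):
        return float(self.value)

def affine(x, w, b):
    # Unmodified user code that knows nothing about parameters
    return w * x + b

w = NumericParam(jnp.log(3.0))
y = affine(jnp.array([1.0, 2.0]), w, 1.0)
print(y)  # approximately [4. 7.]
```

With this minimal sketch, operator-based code works unchanged. Functions such as jnp.exp would still need w.value passed explicitly.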

Documentation

API documentation is available at Read the Docs.