14 changes: 14 additions & 0 deletions brainunit/autograd/_hessian.py
Expand Up @@ -35,6 +35,20 @@ def hessian(
Physical unit-aware version of `jax.hessian <https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html>`_,
computing Hessian of ``fun`` as a dense array.

Example::
>>> import jax.numpy as jnp
>>> import brainunit as u
>>> def scalar_function1(x):
...     return x ** 2 + 3 * x * u.ms + 2 * u.msecond2
>>> hess_fn = u.autograd.hessian(scalar_function1)
>>> hess_fn(jnp.array(1.0) * u.ms)
[2]
>>> def scalar_function2(x):
...     return x ** 3 + 3 * x * u.msecond2 + 2 * u.msecond3
>>> hess_fn = u.autograd.hessian(scalar_function2)
>>> hess_fn(jnp.array(1.0) * u.ms)
[6] * ms

Args:
fun: Function whose Hessian is to be computed. Its arguments at positions
specified by ``argnums`` should be arrays, scalars, or standard Python
Expand Down
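A note on the new Hessian doctest: the result always carries the output unit divided by the input unit squared, which is why the first example is dimensionless and the second carries ms. A minimal sketch reproducing the second case, using only names that appear in the diff above:

import jax.numpy as jnp
import brainunit as u

# Same function as scalar_function2 in the doctest; the output unit is ms ** 3.
def scalar_function2(x):
    return x ** 3 + 3 * x * u.msecond2 + 2 * u.msecond3

hess_fn = u.autograd.hessian(scalar_function2)
h = hess_fn(jnp.array(1.0) * u.ms)
# d2/dx2 of x ** 3 is 6 * x, so at x = 1 ms we expect 6 with unit ms**3 / ms**2 = ms.
print(h)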
38 changes: 38 additions & 0 deletions brainunit/autograd/_jacobian.py
Expand Up @@ -49,6 +49,25 @@ def jacrev(
"""
Physical unit-aware version of `jax.jacrev <https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html>`_.

Example::
>>> import jax.numpy as jnp
>>> import brainunit as u
>>> def simple_function1(x):
...     return x ** 2
>>> jac_fn = u.autograd.jacrev(simple_function1)
>>> jac_fn(jnp.array(3.0) * u.ms)
6.0 * ms
>>> def simple_function2(x, y):
...     return x * y
>>> jac_fn = u.autograd.jacrev(simple_function2, argnums=(0, 1))
>>> x = jnp.array([3.0, 4.0]) * u.ohm
>>> y = jnp.array([5.0, 6.0]) * u.mA
>>> jac_fn(x, y)
([[5., 0.],
[0., 6.]] * mA,
[[3., 0.],
[0., 4.]] * ohm)

Args:
fun: Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers. Specifies which
Expand Down Expand Up @@ -154,6 +173,25 @@ def jacfwd(
"""
Physical unit-aware version of `jax.jacfwd <https://jax.readthedocs.io/en/latest/_autosummary/jax.jacfwd.html>`_.

Example::
>>> import jax.numpy as jnp
>>> import brainunit as u
>>> def simple_function(x):
...     return x ** 2
>>> jac_fn = u.autograd.jacfwd(simple_function)
>>> jac_fn(jnp.array(3.0) * u.ms)
6.0 * ms
>>> def simple_function(x, y):
...     return x * y
>>> jac_fn = u.autograd.jacfwd(simple_function, argnums=(0, 1))
>>> x = jnp.array([3.0, 4.0]) * u.ohm
>>> y = jnp.array([5.0, 6.0]) * u.mA
>>> jac_fn(x, y)
([[5., 0.],
[0., 6.]] * mA,
[[3., 0.],
[0., 4.]] * ohm)

Args:
fun: Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers. Specifies which
Expand Down
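Since jacrev and jacfwd share the same two-argument example, it may help to spell out that each Jacobian block carries the unit of the other argument (d(x * y)/dx has the unit of y, and vice versa). A minimal sketch, assuming only the API shown in this diff:

import jax.numpy as jnp
import brainunit as u

def simple_function2(x, y):
    return x * y

x = jnp.array([3.0, 4.0]) * u.ohm
y = jnp.array([5.0, 6.0]) * u.mA

# Reverse- and forward-mode should agree on this elementwise product:
# a diagonal matrix in mA (w.r.t. x) and a diagonal matrix in ohm (w.r.t. y).
jac_rev = u.autograd.jacrev(simple_function2, argnums=(0, 1))(x, y)
jac_fwd = u.autograd.jacfwd(simple_function2, argnums=(0, 1))(x, y)
print(jac_rev)
print(jac_fwd)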
31 changes: 29 additions & 2 deletions brainunit/autograd/_vector_grad.py
Expand Up @@ -40,8 +40,35 @@ def vector_grad(
unit_aware: bool = True,
):
"""
Compute the gradient of a vector with respect to the input.
"""
Unit-aware computation of the gradient of a vector with respect to the input.

Example::
>>> import jax.numpy as jnp
>>> import brainunit as u
>>> def simple_function(x):
...     return x ** 2
>>> vector_grad_fn = u.autograd.vector_grad(simple_function)
>>> vector_grad_fn(jnp.array([3.0, 4.0]) * u.ms)
[6.0, 8.0] * ms
>>> vector_grad_fn = u.autograd.vector_grad(simple_function, return_value=True)
>>> grad, value = vector_grad_fn(jnp.array([3.0, 4.0]) * u.ms)
>>> grad
[6.0, 8.0] * ms
>>> value
[9.0, 16.0] * ms ** 2

Args:
fun: A Python callable that computes a vector-valued output given arguments.
argnums: Optional, an integer or a tuple of integers. The argument number(s) to differentiate with respect to.
return_value: Optional, bool. Whether to return the value of the function.
has_aux: Optional, whether `fun` returns auxiliary data.
unit_aware: Optional, whether to enable unit-aware computation.

Returns:
A function that computes the gradient of `fun` with respect to
the argument(s) indicated by `argnums`.
"""

_check_callable(func)

@wraps(func)
Expand Down
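The return_value flag in the new vector_grad docstring is worth calling out: with return_value=True the wrapped function returns a (gradient, value) pair. A minimal sketch restating the doctest above, assuming the same API:

import jax.numpy as jnp
import brainunit as u

def simple_function(x):
    return x ** 2

grad_fn = u.autograd.vector_grad(simple_function, return_value=True)
grad, value = grad_fn(jnp.array([3.0, 4.0]) * u.ms)
# grad  -> [6., 8.] * ms         (2 * x, unit ms**2 / ms = ms)
# value -> [9., 16.] * ms ** 2   (x ** 2)
print(grad, value)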