I would like to express my gratitude for the wonderful library Diffrax. I have been using it in my research and have found it to be very useful.
In my research, I am interested in systems of differential equations with a very large number of degrees of freedom. For such systems, storing the full state at every time step is extremely memory-inefficient, so I store only observables instead.
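A back-of-the-envelope sketch of why this matters. It assumes float32 values (4 bytes each), the 1000 time points implied by dt = 0.01 and t_max = 10 in the code below, and the 10**8 oscillators mentioned for GPU runs. trajectory_bytes is a hypothetical helper used only for this estimate:

```python
def trajectory_bytes(n_steps, n_values_per_step, bytes_per_value=4):
    """Bytes needed to store n_values_per_step floats at each of n_steps time points."""
    return n_steps * n_values_per_step * bytes_per_value

n_steps = 1000          # t_max / dt = 10 / 0.01
n_oscillators = 10**8   # GPU-scale system size

full_state = trajectory_bytes(n_steps, n_oscillators)  # every theta at every step
observable_only = trajectory_bytes(n_steps, 1)         # one scalar observable per step

print(f"full trajectory: {full_state / 1e9:.0f} GB")        # full trajectory: 400 GB
print(f"order parameter only: {observable_only / 1e3:.0f} kB")  # order parameter only: 4 kB
```

Saving a single scalar per step shrinks storage by a factor equal to the number of oscillators.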
One of my research targets is the Kuramoto model, which I have implemented in JAX with the following code, including an observable_fn option.
import jax.numpy as jnp
from jax import random, jit
from jax.lax import fori_loop

# Kuramoto model functions
def kuramoto_vector_field(thetas, K, omegas):
    # mean-field form: (rx, ry) is the centroid of the oscillators on the unit circle
    coss, sins = jnp.cos(thetas), jnp.sin(thetas)
    rx, ry = coss.mean(), sins.mean()
    return omegas + K * (ry * coss - rx * sins)

# This function calculates the order parameter of the oscillators.
# The order parameter is the magnitude of the centroid of the oscillators moving on the unit circle:
# if the value is small (nearly zero), the oscillators are not synchronized;
# if the value is large (close to one), the oscillators are in sync!
def orderparam_fn(thetas):
    return jnp.abs(jnp.mean(jnp.exp(1j * thetas)))

# ODE solver functions
def runge_kutta(func, state, dt):
    # one step of the classical fourth-order Runge-Kutta method
    k1 = func(state)
    k2 = func(state + 0.5 * dt * k1)
    k3 = func(state + 0.5 * dt * k2)
    k4 = func(state + dt * k3)
    return state + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def run(func, init_state, dt, t_max, observable_fn):
    update_fn = jit(lambda state: runge_kutta(func, state, dt))
    ts = jnp.arange(0, t_max, dt)
    n_step = ts.shape[0]
    # preallocate one scalar per time step and record the initial observable
    observables = jnp.zeros((n_step,))
    observables = observables.at[0].set(observable_fn(init_state))

    def body_fn(i, val):
        state, observables = val
        new_state = update_fn(state)
        observables = observables.at[i].set(observable_fn(new_state))  # store the observable instead of the state
        return new_state, observables

    _, observables = fori_loop(1, n_step, body_fn, (init_state, observables))
    return ts, observables

# Kuramoto model settings
n_oscillators = 10**3  # the number of oscillators can be 10**8 when I run the code on a GPU
omegas = random.cauchy(random.PRNGKey(0), (n_oscillators,))
K = 4.0

# ODE settings & run!
dt, t_max = 0.01, 10
init_thetas = random.uniform(random.PRNGKey(1), (n_oscillators,)) * 2 * jnp.pi
ts, orderparams = run(
    lambda thetas: kuramoto_vector_field(thetas, K, omegas),
    init_thetas, dt, t_max, orderparam_fn,
)
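For reference, kuramoto_vector_field and orderparam_fn correspond to the standard Kuramoto equations below, with N the number of oscillators. The second equality follows from expanding sin(theta_j - theta_i) = sin(theta_j) cos(theta_i) - cos(theta_j) sin(theta_i) and averaging over j:

```latex
\begin{aligned}
\frac{d\theta_i}{dt}
  &= \omega_i + \frac{K}{N}\sum_{j=1}^{N}\sin(\theta_j - \theta_i)
   = \omega_i + K\left(r_y\cos\theta_i - r_x\sin\theta_i\right),\\
r_x &= \frac{1}{N}\sum_{j=1}^{N}\cos\theta_j, \qquad
r_y = \frac{1}{N}\sum_{j=1}^{N}\sin\theta_j,\\
r &= \left|\frac{1}{N}\sum_{j=1}^{N} e^{i\theta_j}\right| = \sqrt{r_x^2 + r_y^2}.
\end{aligned}
```

Because the pairwise sum collapses to the two means r_x and r_y, each vector-field evaluation costs O(N) rather than O(N^2), which is what makes N = 10**8 practical.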
I was wondering whether it would be possible to achieve this kind of functionality (saving only an observable of the state at each time step, rather than the state itself) in Diffrax?
Thank you for your time and consideration.