
[Feature Request] support for storing observables in Diffrax #221

@yonesuke


I would like to express my gratitude for the wonderful Diffrax library. I have been using it in my research and have found it very useful.

In my research, I study systems of differential equations with a very large number of degrees of freedom. For such systems, storing the full state at every time step is extremely memory-inefficient. Instead, I store only observables, i.e. quantities computed from the state at each step.
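The following is a rough estimate of the memory cost. It assumes float32 state (JAX's default dtype), 10**8 oscillators (the GPU size mentioned in the code comment below) and 1000 saved steps (dt = 0.01, t_max = 10). The helper names are hypothetical.

```python
# Rough storage estimate: full trajectory vs. one scalar observable per step.
# Assumptions (hypothetical): float32 values, 10**8 oscillators, 1000 saved steps.
BYTES_PER_FLOAT32 = 4


def trajectory_bytes(n_dof: int, n_steps: int) -> int:
    """Memory needed to save the whole state at every step."""
    return n_dof * n_steps * BYTES_PER_FLOAT32


def observable_bytes(n_observables: int, n_steps: int) -> int:
    """Memory needed to save only n_observables scalars per step."""
    return n_observables * n_steps * BYTES_PER_FLOAT32


full = trajectory_bytes(10**8, 1000)  # 400_000_000_000 bytes, i.e. 400 GB
obs = observable_bytes(1, 1000)       # 4_000 bytes, i.e. 4 kB
print(f"full trajectory: {full / 1e9:.0f} GB, observables only: {obs / 1e3:.0f} kB")
# prints: full trajectory: 400 GB, observables only: 4 kB
```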

One of my research targets is the Kuramoto model, and I have implemented it using JAX with the following code, including an observable_fn option.

import jax.numpy as jnp
from jax import random, jit
from jax.lax import fori_loop

# Kuramoto model functions
def kuramoto_vector_field(thetas, K, omegas):
    # mean-field form: dtheta_i/dt = omega_i + K * (ry * cos(theta_i) - rx * sin(theta_i)),
    # where rx and ry are the means of cos(theta) and sin(theta) over all oscillators
    coss, sins = jnp.cos(thetas), jnp.sin(thetas)
    rx, ry = coss.mean(), sins.mean()
    return omegas + K * (ry * coss - rx * sins)

# This function calculates the order parameter of the oscillators:
# the distance from the origin of their centroid on the unit circle.
# If the value is small (nearly zero), the oscillators are not synchronized;
# if it is large (close to one), they are in sync.
def orderparam_fn(thetas):
    return jnp.abs(jnp.mean(jnp.exp(1j * thetas)))

# ODE solver functions
def runge_kutta(func, state, dt):
    # one step of the classic fourth-order Runge-Kutta method
    k1 = func(state)
    k2 = func(state + 0.5 * dt * k1)
    k3 = func(state + 0.5 * dt * k2)
    k4 = func(state + dt * k3)
    return state + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def run(func, init_state, dt, t_max, observable_fn):
    update_fn = jit(lambda state: runge_kutta(func, state, dt))
    # an integer step count avoids floating-point off-by-one lengths from jnp.arange(0, t_max, dt)
    n_step = int(round(t_max / dt))
    ts = jnp.arange(n_step) * dt
    observables = jnp.zeros((n_step,))
    observables = observables.at[0].set(observable_fn(init_state))
    def body_fn(i, val):
        state, observables = val
        new_state = update_fn(state)
        observables = observables.at[i].set(observable_fn(new_state))  # store the observable instead of the state
        return new_state, observables
    _, observables = fori_loop(1, n_step, body_fn, (init_state, observables))
    return ts, observables

# Kuramoto model settings
n_oscillators = 10**3  # number of oscillators; this can be 10**8 when I run the code on GPU
omegas = random.cauchy(random.PRNGKey(0), (n_oscillators,))
K = 4.0

# ODE settings & run!!
dt, t_max = 0.01, 10
init_thetas = random.uniform(random.PRNGKey(1), (n_oscillators,)) * 2 * jnp.pi
ts, orderparams = run(lambda thetas: kuramoto_vector_field(thetas, K, omegas), init_thetas, dt, t_max, orderparam_fn)
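For reference, the identity below shows why `kuramoto_vector_field` needs only the two means `rx` and `ry`, and how `orderparam_fn` relates to them. It assumes the standard all-to-all Kuramoto model with coupling $K/N$, which the issue does not write out but which the vector field above matches. The second line follows from $\sin(\theta_j - \theta_i) = \sin\theta_j\cos\theta_i - \cos\theta_j\sin\theta_i$.

```latex
\begin{aligned}
\dot\theta_i &= \omega_i + \frac{K}{N}\sum_{j=1}^{N}\sin(\theta_j - \theta_i) \\
             &= \omega_i + K\left(r_y\cos\theta_i - r_x\sin\theta_i\right),
\qquad r_x = \frac{1}{N}\sum_{j=1}^{N}\cos\theta_j,\quad
       r_y = \frac{1}{N}\sum_{j=1}^{N}\sin\theta_j, \\
r &= \left|\frac{1}{N}\sum_{j=1}^{N} e^{i\theta_j}\right| = \sqrt{r_x^2 + r_y^2}.
\end{aligned}
```

In this mean-field form, each evaluation of the vector field costs $O(N)$ rather than $O(N^2)$. This is what makes very large $N$ feasible, and it is also why storing the full state, rather than the computation, becomes the bottleneck.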

Would it be possible to support this kind of functionality in Diffrax?
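For comparison, here is a minimal sketch of the requested behaviour in plain JAX (not Diffrax), assuming the same fixed-step RK4 integrator as in the code above. `jax.lax.scan` carries only the current state and stacks whatever `observable_fn` returns at each step. As a result, no output buffer needs to be preallocated, and the observable can be any pytree, such as a dict of several quantities. `rk4_step` and `run_scan` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp


def rk4_step(func, state, dt):
    # Classic fourth-order Runge-Kutta step (same scheme as runge_kutta above).
    k1 = func(state)
    k2 = func(state + 0.5 * dt * k1)
    k3 = func(state + 0.5 * dt * k2)
    k4 = func(state + dt * k3)
    return state + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def run_scan(func, init_state, dt, n_step, observable_fn):
    """Take n_step - 1 RK4 steps and return observable_fn at all n_step times.

    Only the current state is carried through the loop, so the full trajectory
    is never materialised. observable_fn may return any pytree of arrays.
    """

    def body_fn(state, _):
        new_state = rk4_step(func, state, dt)
        # The second return value is stacked by scan along a new leading axis.
        return new_state, observable_fn(new_state)

    _, later_obs = jax.lax.scan(body_fn, init_state, xs=None, length=n_step - 1)
    first_obs = observable_fn(init_state)
    # Prepend the observable at t = 0 to every leaf of the output pytree.
    observables = jax.tree_util.tree_map(
        lambda first, rest: jnp.concatenate([first[None], rest]), first_obs, later_obs
    )
    ts = jnp.arange(n_step) * dt
    return ts, observables
```

Compared with the `fori_loop` version, `scan` removes the manual `.at[i].set` bookkeeping. The trade-off is that the number of steps must be a static Python integer.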

Thank you for your time and consideration.

Metadata


Assignees: No one assigned
Labels: feature (New feature), next (Higher-priority items)
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
