Poor performance with a 'steady-state' population #6

Open
@Jhsmit

Hi, thanks for making this lovely package.

I have a simple ODE model with the following code:

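```python
import tensorflow as tf
import matplotlib.pyplot as plt
from tfdiffeq import odeint


class Kinetics(tf.keras.Model):
    """Linear three-state kinetics: s0 <-> s1 -> s2."""

    def __init__(self, k1, k2, k3):
        super().__init__()
        self.k1, self.k2, self.k3 = k1, k2, k3

    @tf.function
    def call(self, t, y):
        s0, s1, s2 = y[0], y[1], y[2]

        d_0 = -self.k1 * s0 + self.k2 * s1                 # s0 -> s1 (k1), s1 -> s0 (k2)
        d_1 = -self.k2 * s1 - self.k3 * s1 + self.k1 * s0  # s1 leaves via k2 and k3
        d_2 = self.k3 * s1                                 # s1 -> s2 (k3)

        return tf.stack([d_0, d_1, d_2])


with tf.device('/gpu:0'):
    tf.keras.backend.set_floatx('float64')
    k1 = 1.
    k2 = 2.
    k3 = 3.
    NUM_SAMPLES = 100

    # Cast the time grid to float64 so it matches the dtype of y_init.
    t = tf.cast(tf.linspace(0., 10., num=NUM_SAMPLES), tf.float64)
    y_init = tf.constant([1., 0., 0.], dtype=tf.float64)

    func = Kinetics(k1, k2, k3)
    result = odeint(func, y_init, t)  # shape: (NUM_SAMPLES, 3)


# Plot each population against time.
plt.figure()
for name, r in zip(['s0', 's1', 's2'], tf.transpose(result)):
    plt.plot(t.numpy(), r.numpy(), label=name)
plt.xlabel('t')
plt.legend()
plt.show()
```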

This works fine and gives me the expected result directly. However, when I change the parameters to:

```python
k1 = 1.
k2 = 2000.
k3 = 3.
```

Now s1 is in a steady state: its population stays almost zero over the whole time interval. With these values the calculation takes a very long time, whereas scipy's odeint still solves it very quickly.
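One likely reason (an interpretation, not confirmed in this thread) is that k2 = 2000 makes the system stiff, meaning it contains one very fast and one very slow timescale. The right-hand side in `Kinetics.call` is linear, so its Jacobian is a constant matrix, and the spread of that matrix's eigenvalues measures the stiffness. The NumPy sketch below computes this spread. `kinetics_jacobian` and `stiffness_ratio` are hypothetical helper names and are not part of tfdiffeq.

```python
import numpy as np


def kinetics_jacobian(k1, k2, k3):
    """Jacobian of the Kinetics right-hand side (constant, since the ODE is linear)."""
    return np.array([
        [-k1, k2, 0.0],          # partial derivatives of d_0 w.r.t. (s0, s1, s2)
        [k1, -(k2 + k3), 0.0],   # partial derivatives of d_1
        [0.0, k3, 0.0],          # partial derivatives of d_2
    ])


def stiffness_ratio(k1, k2, k3):
    """Ratio of the fastest to the slowest decay rate.

    The zero eigenvalue (from conservation of s0 + s1 + s2) is dropped
    using a threshold relative to the largest eigenvalue magnitude.
    """
    eig = np.linalg.eigvals(kinetics_jacobian(k1, k2, k3))
    mags = np.abs(eig)
    nonzero = eig[mags > 1e-9 * mags.max()]
    rates = np.sort(np.abs(nonzero.real))
    return rates[-1] / rates[0], rates


for ks in [(1.0, 2.0, 3.0), (1.0, 2000.0, 3.0)]:
    ratio, rates = stiffness_ratio(*ks)
    print(ks, "decay rates:", rates, "stiffness ratio: %.3g" % ratio)
```

For k = (1, 2, 3), the two nonzero decay rates are about 0.55 and 5.45, a ratio of about 10. For k2 = 2000 they are about 0.0015 and 2004, a ratio of about 1.3 million. An explicit adaptive solver must then keep its step size on the order of 1/2004 for stability over the whole interval. After the initial transient, however, the solution itself only changes on a timescale of hundreds of time units.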

Any suggestions as to why this happens and what I can do to improve it?
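Stiffness would also explain the comparison with SciPy. `scipy.integrate.odeint` wraps LSODA from ODEPACK, which detects stiffness and switches from a nonstiff Adams method to a stiff BDF method. The sketch below uses `scipy.integrate.solve_ivp` rather than `odeint` so that both approaches can be run side by side on the k2 = 2000 system. It counts right-hand-side evaluations for the explicit Dormand-Prince method (`RK45`) and for the implicit `BDF` method. Treating `RK45` as a stand-in for tfdiffeq's default adaptive solver is an assumption.

```python
import numpy as np
from scipy.integrate import solve_ivp

K1, K2, K3 = 1.0, 2000.0, 3.0

# Constant Jacobian of the linear Kinetics system, passed to the implicit solver.
JAC = np.array([
    [-K1, K2, 0.0],
    [K1, -(K2 + K3), 0.0],
    [0.0, K3, 0.0],
])


def rhs(t, y):
    """Same right-hand side as Kinetics.call, written for SciPy."""
    s0, s1, s2 = y
    return [
        -K1 * s0 + K2 * s1,
        -K2 * s1 - K3 * s1 + K1 * s0,
        K3 * s1,
    ]


y0 = [1.0, 0.0, 0.0]
t_eval = np.linspace(0.0, 10.0, 100)

# Explicit Runge-Kutta (Dormand-Prince 5(4)): step size limited by stability.
explicit = solve_ivp(rhs, (0.0, 10.0), y0, method="RK45", t_eval=t_eval)
# Implicit BDF: designed for stiff problems, can take large steps.
implicit = solve_ivp(rhs, (0.0, 10.0), y0, method="BDF", jac=JAC, t_eval=t_eval)

print("RK45 right-hand-side evaluations:", explicit.nfev)
print("BDF  right-hand-side evaluations:", implicit.nfev)
print("max difference at t = 10:",
      np.max(np.abs(explicit.y[:, -1] - implicit.y[:, -1])))
```

On this system, the explicit method needs far more right-hand-side evaluations than BDF to reach t = 10, yet both finish in essentially the same state.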

I've tried TensorFlow 2.1.0 and 2.2.0, adding the @tf.function decorator, and @KuzMenachem's fork.
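None of those changes alters the integration method, so a plausible remedy (not confirmed in this thread) is an implicit solver. The NumPy sketch below is hypothetical illustration code and not part of tfdiffeq. It integrates the k2 = 2000 system with a step of h = 0.1 using explicit (forward) Euler and implicit (backward) Euler, then compares both against the exact solution of the linear ODE. For a mode with decay rate λ < 0, the per-step factor of backward Euler is 1/(1 - hλ), which stays below 1 for any h > 0. The factor for forward Euler is 1 + hλ, which exceeds 1 in magnitude once h > 2/|λ|, roughly 0.001 here.

```python
import numpy as np

K1, K2, K3 = 1.0, 2000.0, 3.0
A = np.array([
    [-K1, K2, 0.0],
    [K1, -(K2 + K3), 0.0],
    [0.0, K3, 0.0],
])
Y0 = np.array([1.0, 0.0, 0.0])


def forward_euler(A, y0, h, n_steps):
    """Explicit Euler: y_{n+1} = y_n + h * A @ y_n."""
    y = y0.copy()
    for _ in range(n_steps):
        y = y + h * (A @ y)
    return y


def backward_euler(A, y0, h, n_steps):
    """Implicit Euler for a linear ODE: solve (I - h A) y_{n+1} = y_n at each step."""
    M = np.eye(len(y0)) - h * A
    y = y0.copy()
    for _ in range(n_steps):
        y = np.linalg.solve(M, y)
    return y


def exact(A, y0, t):
    """Reference solution y(t) = V exp(diag(lambda) t) V^{-1} y0 via eigendecomposition."""
    lam, V = np.linalg.eig(A)
    return (V @ (np.exp(lam * t) * np.linalg.solve(V, y0))).real


h, n = 0.1, 100  # 100 steps of size 0.1 cover t in [0, 10]
print("forward Euler :", forward_euler(A, Y0, h, n))
print("backward Euler:", backward_euler(A, Y0, h, n))
print("exact         :", exact(A, Y0, h * n))
```

With the same step size, forward Euler blows up by hundreds of orders of magnitude, while backward Euler stays close to the exact solution and conserves s0 + s1 + s2. Within TensorFlow, the equivalent would be a stiff-aware solver such as TensorFlow Probability's `tfp.math.ode.BDF`, although whether it suits this model has not been checked here.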
