Hi, thanks for making this lovely package.
I have a simple ODE model with the following code:
import tensorflow as tf
import matplotlib.pyplot as plt
from tfdiffeq import odeint


class Kinetics(tf.keras.Model):
    """Reaction network s0 <-> s1 -> s2 with rate constants k1, k2, k3."""

    def __init__(self, k1, k2, k3):
        super().__init__()
        self.k1, self.k2, self.k3 = k1, k2, k3

    @tf.function
    def call(self, t, y):
        s0, s1, s2 = y[0], y[1], y[2]
        # s0 -> s1 at rate k1, s1 -> s0 at rate k2, s1 -> s2 at rate k3
        d_0 = - self.k1 * s0 + self.k2 * s1
        d_1 = -self.k2 * s1 - self.k3 * s1 + self.k1 * s0
        d_2 = self.k3 * s1
        return tf.stack([d_0, d_1, d_2])


with tf.device('/gpu:0'):
    tf.keras.backend.set_floatx('float64')
    k1 = 1.
    k2 = 2.
    k3 = 3.
    NUM_SAMPLES = 100
    t = tf.linspace(0., 10., num=NUM_SAMPLES)
    y_init = tf.constant([1., 0., 0.], dtype=tf.float64)
    func = Kinetics(k1, k2, k3)
    result = odeint(func, y_init, t)  # shape (NUM_SAMPLES, 3)

# One curve per species, plotted against the sample index
plt.figure()
for r in tf.transpose(result):
    plt.plot(r)
plt.show()
This works fine and gives me the expected result directly. However, when I change the parameters to:
k1 = 1.
k2 = 2000.
k3 = 3.
With these values s1 is in a steady state: its population stays almost zero over the whole time interval.
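One way to see why the second parameter set behaves so differently: the model is linear, dy/dt = A y, so the eigenvalues of A set its time scales. The minimal NumPy sketch below compares the ratio between the fastest and slowest nonzero decay rates for k2 = 2 and k2 = 2000. The helpers `rate_matrix` and `stiffness_ratio` are hypothetical, written only for this illustration.

```python
import numpy as np


def rate_matrix(k1, k2, k3):
    # dy/dt = A @ y for y = [s0, s1, s2], same equations as Kinetics.call
    return np.array([
        [-k1, k2, 0.0],
        [k1, -k2 - k3, 0.0],
        [0.0, k3, 0.0],
    ])


def stiffness_ratio(k1, k2, k3):
    # Ratio of the fastest to the slowest nonzero decay rate of the system
    eigenvalues = np.linalg.eigvals(rate_matrix(k1, k2, k3))
    rates = np.abs(eigenvalues.real)
    nonzero = rates[rates > 1e-12]  # drop the zero eigenvalue (s2 only accumulates)
    return nonzero.max() / nonzero.min()


for k2 in (2.0, 2000.0):
    print(f"k2 = {k2}: stiffness ratio = {stiffness_ratio(1.0, k2, 3.0):.3g}")
```

The ratio is about 10 for k2 = 2 and about 10^6 for k2 = 2000. A ratio this large is what makes a problem stiff: an explicit adaptive solver has to keep its step size roughly on the order of 1/2000 for stability, even after s1 has settled, whereas implicit methods can take much larger steps.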
With these parameters the calculation takes a very long time, whereas scipy's odeint still solves it very quickly.
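For comparison, here is a minimal sketch of the same model with `scipy.integrate.odeint`. The function name `kinetics` is only illustrative. scipy's odeint wraps LSODA, which monitors stiffness and automatically switches from a non-stiff Adams method to a stiff BDF method. With `full_output=True`, the returned info dict's `'mused'` entry records the method used at each step (1 = Adams, 2 = BDF).

```python
import numpy as np
from scipy.integrate import odeint


def kinetics(y, t, k1, k2, k3):
    # Same right-hand side as Kinetics.call; scipy's odeint calls func(y, t, *args)
    s0, s1, s2 = y
    d_0 = -k1 * s0 + k2 * s1
    d_1 = -k2 * s1 - k3 * s1 + k1 * s0
    d_2 = k3 * s1
    return [d_0, d_1, d_2]


t = np.linspace(0.0, 10.0, 100)
y_init = [1.0, 0.0, 0.0]
result, info = odeint(kinetics, y_init, t, args=(1.0, 2000.0, 3.0), full_output=True)

print(result.shape)   # one row per time point, one column per species
print(info["mused"])  # method indicator per output step: 1 = Adams, 2 = BDF
```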
Any suggestions on why this is and what I can do to improve this?
I've tried TensorFlow 2.1.0 and 2.2.0, adding the @tf.function decorator, and @KuzMenachem's fork.
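Because this particular model is linear, one possible workaround is to skip time stepping entirely and evaluate the exact solution y(t) = exp(A t) y(0) with `tf.linalg.expm`. This is not a tfdiffeq feature and only works for linear systems. The sketch below assumes TensorFlow 2.x with float64 throughout. `solve_linear_kinetics` is a hypothetical helper.

```python
import tensorflow as tf


def solve_linear_kinetics(k1, k2, k3, y_init, t):
    # Rate matrix A with dy/dt = A @ y, same equations as Kinetics.call
    A = tf.convert_to_tensor(
        [[-k1, k2, 0.0],
         [k1, -k2 - k3, 0.0],
         [0.0, k3, 0.0]],
        dtype=tf.float64,
    )
    # One matrix exponential per output time: exp(A * t_i), shape (N, 3, 3)
    At = t[:, None, None] * A[None, :, :]
    propagators = tf.linalg.expm(At)
    # Apply each propagator to the initial state, giving shape (N, 3)
    return tf.einsum('nij,j->ni', propagators, y_init)


t = tf.linspace(tf.constant(0.0, dtype=tf.float64),
                tf.constant(10.0, dtype=tf.float64), 100)
y_init = tf.constant([1.0, 0.0, 0.0], dtype=tf.float64)
result = solve_linear_kinetics(1.0, 2000.0, 3.0, y_init, t)
```

The cost here does not depend on how stiff the rates are. For nonlinear kinetics this trick does not apply, and an implicit (stiff) solver is the usual remedy.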