
Commit 4c555b6

Setup for dynamical systems
1 parent 932d0e1 commit 4c555b6

File tree

7 files changed: +412 -11 lines changed

examples/lotka/lotka.py

Lines changed: 278 additions & 0 deletions
@@ -0,0 +1,278 @@
#%%
import time

## Use jax cpu
# jax.config.update("jax_platform_name", "cpu")

## Limit JAX memory usage
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp

import numpy as np
import equinox as eqx
import optax
import matplotlib.pyplot as plt
from functools import partial
import datetime
# from flax.metrics import tensorboard

from nodepint.utils import get_new_keys, sbplot, seconds_to_hours
from nodepint.training import train_project_neural_ode, test_neural_ode
# from nodepint.data import load_jax_dataset, get_dataset_features, preprocess_mnist
from nodepint.data import load_mnist_dataset_torch, load_lotka_volterra_dataset
from nodepint.integrators import dopri_integrator, euler_integrator, rk4_integrator, dopri_integrator_diff, dopri_integrator_diffrax
from nodepint.pint import newton_root_finder, direct_root_finder, fixed_point_finder, direct_root_finder_aug, parareal
from nodepint.sampling import random_sampling, identity_sampling, neural_sampling

import cProfile

print("Available devices:", jax.devices())
import warnings
warnings.filterwarnings("ignore")

SEED = 2026

## Reload the nodepint package before each cell run
%load_ext autoreload
%autoreload 2

#%% [markdown]
# ## Define the neural nets

#%%

class Encoder(eqx.Module):
    """Identity encoder: the Lotka-Volterra state is passed through unchanged."""
    # layers: list

    #### Idea: use a tensordot, summing over the last three/two dimensions of the
    ## basis. If basis has shape (2,4,4, 1,28,28) and x has shape (1,28,28), the
    ## tensordot returns something of shape (2,4,4). This can then be multiplied
    ## by a learned weight of shape (2,4,4) as well.

    def __init__(self, key=None):
        # keys = get_new_keys(key, num=3)
        # self.layers = [eqx.nn.Conv2d(1, 64, (3, 3), stride=1, key=keys[0]), jax.nn.relu, eqx.nn.GroupNorm(64, 64),
        #                eqx.nn.Conv2d(64, 64, (4, 4), stride=2, padding=1, key=keys[1]), jax.nn.relu, eqx.nn.GroupNorm(64, 64),
        #                eqx.nn.Conv2d(64, 64, (4, 4), stride=2, padding=1, key=keys[2])]
        pass

    def __call__(self, x):
        # for layer in self.layers:
        #     x = layer(x)
        return x

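## A minimal sketch (assumed; not used below) of the tensordot projection the
## Encoder comment above describes: contracting a hypothetical basis of shape
## (2,4,4, 1,28,28) with an input of shape (1,28,28) over the last three axes
## yields (2,4,4), which a learned weight of the same shape can then rescale.
demo_basis = jnp.ones((2, 4, 4, 1, 28, 28))
demo_weight = jnp.ones((2, 4, 4))
demo_x = jnp.ones((1, 28, 28))
demo_coeffs = jnp.tensordot(demo_basis, demo_x, axes=3)     ## contracts the last 3 axes -> (2, 4, 4)
demo_out = demo_coeffs * demo_weight                        ## elementwise rescaling, shape (2, 4, 4)
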
class Processor(eqx.Module):
    """MLP vector field: maps the 2-D Lotka-Volterra state to its time derivative."""
    layers: list

    def __init__(self, key=None):
        keys = get_new_keys(key, num=4)
        self.layers = [eqx.nn.Linear(2, 64, key=keys[0]),
                       jax.nn.softplus,
                       eqx.nn.Linear(64, 64, key=keys[1]),
                       jax.nn.softplus,
                       eqx.nn.Linear(64, 64, key=keys[2]),
                       jax.nn.softplus,
                       eqx.nn.Linear(64, 2, key=keys[3])]

    def __call__(self, x, t):
        # y = jnp.concatenate([jnp.broadcast_to(t, (1,)+x.shape[1:]), x], axis=0)
        y = x
        for layer in self.layers:
            y = layer(y)
        return y


class Decoder(eqx.Module):
    """Identity decoder: the processor output is used directly."""
    # layers: list

    def __init__(self, key=None):
        # key = get_new_keys(key, 1)
        # self.layers = [eqx.nn.GroupNorm(64, 64), jax.nn.relu,
        #                eqx.nn.AvgPool2d((6, 6)), lambda x: jnp.reshape(x, (64,)),
        #                eqx.nn.Linear(64, 10, key=key)]
        pass

    def __call__(self, x):
        # for layer in self.layers:
        #     x = layer(x)
        return x

#%% [markdown]
# ## Load the dataset

#%%

ds = load_lotka_volterra_dataset(root_dir="./data", split="train")
# ds = make_dataloader_torch(ds, subset_size="all", seed=SEED, norm_factor=255.)

print("Number of training examples:", len(ds))

## Visualise a random datapoint
np.random.seed(time.time_ns() % (2**32))
point_id = np.random.randint(0, len(ds))
init_cond, trajectory = ds[point_id]
t_eval = ds.t

plt.figure(figsize=(8, 4))
plt.title(f"Sample trajectory id: {point_id}")
plt.plot(t_eval, trajectory[:, 0], label="Prey")
plt.plot(t_eval, trajectory[:, 1], label="Predator")
plt.legend()
plt.show();

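## For reference, a minimal sketch (assumed; not nodepint's data-generation
## code) of the Lotka-Volterra vector field such trajectories typically come
## from, integrated here with diffrax. The coefficients (1.5, 1.0, 3.0, 1.0)
## are illustrative placeholders, not the ones used to generate ./data.
import diffrax

def lotka_volterra(t, y, args):
    alpha, beta, gamma, delta = args
    prey, pred = y
    d_prey = alpha*prey - beta*prey*pred        ## prey reproduce and get eaten
    d_pred = delta*prey*pred - gamma*pred       ## predators feed and die off
    return jnp.array([d_prey, d_pred])

ref_sol = diffrax.diffeqsolve(diffrax.ODETerm(lotka_volterra), diffrax.Dopri5(),
                              t0=float(t_eval[0]), t1=float(t_eval[-1]), dt0=0.01,
                              y0=jnp.asarray(init_cond), args=(1.5, 1.0, 3.0, 1.0),
                              stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-6),
                              saveat=diffrax.SaveAt(ts=jnp.asarray(t_eval)))
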
#%% [markdown]
# ## Define training parameters

#%%

## Optimiser and loss
optim_scheme = optax.adam
# times = tuple(np.linspace(0, 1, 101).flatten())
times = (t_eval[0], t_eval[-1], t_eval.shape[0])    ## t0, tf, nb_times (used to solve the ODE when an adaptive time stepper is not used; not for eval)

integrator_args = (1e-8, 1e-8, jnp.inf, 20, 10, "checkpointed")     ## rtol, atol, max_dt, max_steps, max_steps_rev, kind (typically used by adaptive time steppers)
fixed_point_args = (1., 1e-12, 20)                                  ## learning_rate, tol, max_iter

# loss = optax.softmax_cross_entropy
# loss = optax.softmax_cross_entropy_with_integer_labels
loss = optax.l2_loss

keys = get_new_keys(SEED, num=3)
neural_nets = (Encoder(key=keys[0]), Processor(key=keys[1]), Decoder(key=keys[2]))

## PinT scheme with only mandatory arguments

nb_epochs = 5000
batch_size = 4*1
total_steps = nb_epochs*(len(ds)//batch_size)

scheduler = optax.piecewise_constant_schedule(init_value=1e-3, boundaries_and_scales={int(total_steps*0.5): 0.75, int(total_steps*0.75): 0.75})
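## Worked out: the learning rate starts at 1e-3, is multiplied by 0.75 at 50%
## of total_steps (-> 7.5e-4), and by 0.75 again at 75% (-> ~5.6e-4).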
key = get_new_keys(SEED)

#%% [markdown]
# ## Train the model

#%%

train_params = {"neural_nets": neural_nets,
                "data": ds,
                # "pint_scheme": fixed_point_finder,
                # "pint_scheme": direct_root_finder_aug,
                "pint_scheme": parareal,
                "samp_scheme": identity_sampling,
                # "integrator": rk4_integrator,
                "integrator": dopri_integrator_diffrax,
                "integrator_args": integrator_args,
                "loss_fn": loss,
                "optim_scheme": optim_scheme,
                "nb_processors": 20-1,
                "scheduler": scheduler,
                "times": times,
                "fixed_point_args": fixed_point_args,
                "nb_epochs": nb_epochs,
                "batch_size": batch_size,
                "repeat_projection": 1,
                "nb_vectors": 5,
                "force_serial": False,
                "key": key}

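## A minimal parareal sketch (illustrative; assumed, not nodepint's actual
## `parareal` implementation). The scheme corrects a cheap coarse propagator G
## with an accurate fine propagator F whose solves over the time slabs are
## mutually independent, hence parallel across nb_processors:
##     U[n+1] <- G(U_new[n]) + F(U_old[n]) - G(U_old[n])
def parareal_sketch(f, u0, t0, tf, nb_slabs=19, nb_iters=5, fine_steps=50):
    ts = np.linspace(t0, tf, nb_slabs + 1)

    def euler(u, ta, tb, nsteps):
        dt = (tb - ta) / nsteps
        for _ in range(nsteps):
            u = u + dt * f(u)
        return u

    coarse = lambda u, ta, tb: euler(u, ta, tb, 1)          ## one Euler step per slab
    fine = lambda u, ta, tb: euler(u, ta, tb, fine_steps)   ## many Euler steps per slab

    ## Initial coarse sweep (sequential)
    U = [u0]
    for n in range(nb_slabs):
        U.append(coarse(U[n], ts[n], ts[n + 1]))

    for _ in range(nb_iters):
        F_old = [fine(U[n], ts[n], ts[n + 1]) for n in range(nb_slabs)]     ## parallelisable
        G_old = [coarse(U[n], ts[n], ts[n + 1]) for n in range(nb_slabs)]
        for n in range(nb_slabs):   ## sequential correction sweep
            U[n + 1] = coarse(U[n], ts[n], ts[n + 1]) + F_old[n] - G_old[n]

    return U
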
start_time = time.time()
cpu_start_time = time.process_time()

trained_networks, shooting_fn, loss_hts, errors_hts, nb_iters_hts = train_project_neural_ode(**train_params)

cpu_time = time.process_time() - cpu_start_time
wall_time = time.time() - start_time

# print("\nNumber of iterations till eventual PinT convergence:\n", np.asarray(nb_iters_hts))
# print("Errors during PinT iterations:\n", np.asarray(errors_hts))

time_in_hmsecs = seconds_to_hours(wall_time)
print("\nTotal training time: %d hours %d mins %d secs" % time_in_hmsecs)

#%% [markdown]
# ## Analyse the loss history

#%%

# ## Plot the loss histories per projection iteration
# labels = [str(i) for i in range(len(loss_hts))]
# epochs = range(len(loss_hts[0]))

# sbplot(epochs, jnp.stack(loss_hts, axis=-1), label=labels, x_label="epochs", y_scale="log", title="Loss histories");

## Loss history across all iterations
total_loss = np.concatenate(loss_hts, axis=0)
total_epochs = 1 + np.arange(len(total_loss))

ax = sbplot(total_epochs, total_loss, x_label="epochs", y_scale="log", title="Total loss history");

## Save the plot
# plt.savefig("loss_history_coarse_euler.png")

#%% [markdown]
# ## Compute metrics on a test dataset

#%%

## Load the test dataset
test_ds = load_lotka_volterra_dataset(root_dir="./data", split="test")

print("\nNumber of testing examples:", len(test_ds))

def acc_fn(x, y):
    ## Mean squared error between predicted and true trajectories
    return jnp.mean((x-y)**2)

test_params = {"neural_nets": trained_networks,
               "data": test_ds,
               "pint_scheme": parareal,     ## If None, the fixed_point_ad_rule is used
               # "pint_scheme": direct_scheme,
               # "integrator": rk4_integrator,
               "integrator": dopri_integrator_diffrax,
               "integrator_args": integrator_args,
               "fixed_point_args": fixed_point_args,
               "acc_fn": acc_fn,
               "shooting_fn": shooting_fn,
               "nb_processors": 20-1,
               "times": times,
               "batch_size": 4}

start_time = time.time()

avg_acc = test_neural_ode(**test_params)

test_wall_time = time.time() - start_time
time_in_hms = seconds_to_hours(test_wall_time)

print(f"\nAverage test loss: {avg_acc:.8f}")
print("Test time: %d hours %d mins %d secs" % time_in_hms)

# %%

nodepint/data.py

Lines changed: 63 additions & 1 deletion
@@ -7,6 +7,7 @@
 from typing import Collection
 # from datasets import Dataset, load_dataset
 
+import torch
 from torch.utils.data import DataLoader, Dataset
 from torchvision import datasets, transforms
 
@@ -22,7 +23,64 @@ def load_mnist_dataset_torch(root="./data", train=True) -> Dataset:
 
     return datasets.MNIST(root=root, train=train, transform=transform, download=True)
 
-def make_dataloader_torch(ds: Dataset, batch_size: int = 32, num_workers: int=32, shuffle: bool = True) -> DataLoader:
+
+class ODEDataset(Dataset):
+    """ODE trajectory dataset."""
+
+    def __init__(self, root_dir, split="train"):
+        """
+        Arguments:
+            root_dir (string): Directory with all the trajectories.
+            split (string): "train" or "test" (in-domain), or "ood_train" or "ood_test" (out-of-domain)
+        """
+
+        self.root_dir = root_dir if root_dir[-1] == "/" else root_dir+"/"
+        self.split = split
+
+        if self.split == "train":
+            filename = self.root_dir+"train.npz"
+        elif self.split == "test":
+            filename = self.root_dir+"test.npz"
+        elif self.split == "ood_train":
+            filename = self.root_dir+"ood_train.npz"
+        elif self.split == "ood_test":
+            filename = self.root_dir+"ood_test.npz"
+        else:
+            raise ValueError(f"Unknown split: {self.split}")
+
+        data = np.load(filename)
+        self.X, self.t = data["X"][5], data["t"]  ## TODO: use any environment e here
+
+    def __len__(self):
+        return self.X.shape[0]  ## Number of trajectories
+
+    def __getitem__(self, idx):
+        if torch.is_tensor(idx):
+            idx = idx.tolist()
+
+        # sample = {'init_state': self.X[idx, 0, ...], 'trajectory': self.X[idx, :, ...]}
+        sample = (self.X[idx, 0, ...], self.X[idx, :, ...])
+
+        return sample
+
+
+def load_lotka_volterra_dataset(root_dir, split) -> Dataset:
+    return ODEDataset(root_dir, split)
+
+
+def make_dataloader_torch(ds: Dataset, batch_size: int = 32, num_workers: int = 24, shuffle: bool = True) -> DataLoader:
     return DataLoader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
 
 
@@ -31,6 +89,10 @@ def make_dataloader_torch(ds: Dataset, batch_size: int = 32, num_workers: int=32
 
 
+
+
 # def load_jax_dataset_hf(**kwargs) -> Dataset:
 #     """
 #     The load_jax_dataset function loads a dataset from the HuggingFace Hub or from the file system, converts it to JAX format, and returns it.
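
A quick usage sketch for the new ODEDataset (assumed: "X" in the .npz files has
shape (nb_envs, nb_trajectories, nb_timesteps, 2), so the hard-coded index 5
selects one environment of a 2-species system; the "./data" path is hypothetical):

    from nodepint.data import load_lotka_volterra_dataset

    ds = load_lotka_volterra_dataset(root_dir="./data", split="train")
    init_state, trajectory = ds[0]    # shapes (2,) and (nb_timesteps, 2)
    t_eval = ds.t                     # shape (nb_timesteps,)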
