#%%
import time

## Run JAX on CPU (uncomment to force)
# jax.config.update("jax_platform_name", "cpu")

## Limit JAX memory usage by disabling XLA's default GPU memory preallocation
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
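## (Assumption: alternatively, XLA supports capping the memory fraction
## instead of disabling preallocation:)
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"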

import jax
import jax.numpy as jnp

import numpy as np
import equinox as eqx
import optax
import matplotlib.pyplot as plt
from functools import partial
import datetime
# from flax.metrics import tensorboard

from nodepint.utils import get_new_keys, sbplot, seconds_to_hours
from nodepint.training import train_project_neural_ode, test_neural_ode
# from nodepint.data import load_jax_dataset, get_dataset_features, preprocess_mnist
from nodepint.data import load_mnist_dataset_torch, load_lotka_volterra_dataset
from nodepint.integrators import dopri_integrator, euler_integrator, rk4_integrator, dopri_integrator_diff, dopri_integrator_diffrax
from nodepint.pint import newton_root_finder, direct_root_finder, fixed_point_finder, direct_root_finder_aug, parareal
from nodepint.sampling import random_sampling, identity_sampling, neural_sampling

import cProfile

print("Available devices:", jax.devices())

import warnings
warnings.filterwarnings("ignore")

SEED = 2026

## Reload the nodepint package before each cell run
%load_ext autoreload
%autoreload 2

#%% [markdown]
# ## Define the neural networks

#%%

class Encoder(eqx.Module):
    # layers: list

    #### Idea: use a tensordot and contract over the last three (or two)
    ## dimensions of the basis. If the basis has shape (2,4,4, 1,28,28) and x
    ## has shape (1,28,28), the tensordot returns something of shape (2,4,4).
    ## Finally, this can be multiplied by a learned weight of shape (2,4,4) as well.

    def __init__(self, key=None):
        # keys = get_new_keys(key, num=3)
        # self.layers = [eqx.nn.Conv2d(1, 64, (3, 3), stride=1, key=keys[0]), jax.nn.relu, eqx.nn.GroupNorm(64, 64),
        #                eqx.nn.Conv2d(64, 64, (4, 4), stride=2, padding=1, key=keys[1]), jax.nn.relu, eqx.nn.GroupNorm(64, 64),
        #                eqx.nn.Conv2d(64, 64, (4, 4), stride=2, padding=1, key=keys[2])]
        pass

    def __call__(self, x):
        ## Identity encoder: the state is passed through unchanged
        # for layer in self.layers:
        #     x = layer(x)
        return x

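## A minimal sketch of the tensordot idea described above (hypothetical
## shapes, for illustration only; the current Encoder is the identity):
demo_basis = jnp.ones((2, 4, 4, 1, 28, 28))   # stand-in for a learned basis
demo_weight = jnp.ones((2, 4, 4))             # stand-in for a learned weight
demo_x = jnp.ones((1, 28, 28))                # one image-shaped input
demo_enc = jnp.tensordot(demo_basis, demo_x, axes=3) * demo_weight   # shape (2, 4, 4)
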

class Processor(eqx.Module):
    layers: list

    def __init__(self, key=None):
        keys = get_new_keys(key, num=4)
        self.layers = [eqx.nn.Linear(2, 64, key=keys[0]),
                       jax.nn.softplus,
                       eqx.nn.Linear(64, 64, key=keys[1]),
                       jax.nn.softplus,
                       eqx.nn.Linear(64, 64, key=keys[2]),
                       jax.nn.softplus,
                       eqx.nn.Linear(64, 2, key=keys[3])]

    def __call__(self, x, t):
        ## The vector field is autonomous, so t is ignored
        # y = jnp.concatenate([jnp.broadcast_to(t, (1,)+x.shape[1:]), x], axis=0)
        y = x
        for layer in self.layers:
            y = layer(y)
        return y


class Decoder(eqx.Module):

    # layers: list

    def __init__(self, key=None):
        # key = get_new_keys(key, 1)
        # self.layers = [eqx.nn.GroupNorm(64, 64), jax.nn.relu,
        #                eqx.nn.AvgPool2d((6, 6)), lambda x: jnp.reshape(x, (64,)),
        #                eqx.nn.Linear(64, 10, key=key)]
        # # self.layers = [eqx.nn.GroupNorm(4, 4), jax.nn.relu,
        # #                eqx.nn.AvgPool2d((6, 6)), lambda x: jnp.reshape(x, (4,)),
        # #                eqx.nn.Linear(4, 10, key=key)]
        pass

    def __call__(self, x):
        ## Identity decoder: the ODE state is the prediction itself
        # for layer in self.layers:
        #     x = layer(x)
        return x


#%% [markdown]
# ## Load the dataset

#%%

ds = load_lotka_volterra_dataset(root_dir="./data", split="train")
# ds = make_dataloader_torch(ds, subset_size="all", seed=SEED, norm_factor=255.)

print("Number of training examples:", len(ds))

## Visualise a random datapoint
np.random.seed(time.time_ns() % (2**32))
point_id = np.random.randint(0, len(ds))
init_cond, trajectory = ds[point_id]
t_eval = ds.t

## Set figure size
plt.figure(figsize=(8, 4))
plt.title(f"Sample trajectory id: {point_id}")
plt.plot(t_eval, trajectory[:, 0], label="Prey")
plt.plot(t_eval, trajectory[:, 1], label="Predator")
plt.xlabel("Time")
plt.legend()
plt.show()

#%% [markdown]
# ## Define training parameters

#%%

## Optimiser scheme (the loss function is chosen below)
optim_scheme = optax.adam
# times = tuple(np.linspace(0, 1, 101).flatten())
times = (t_eval[0], t_eval[-1], t_eval.shape[0])    ## t0, tf, nb_times (grid for solving the ODE when an adaptive time stepper is not used; not for evaluation)

integrator_args = (1e-8, 1e-8, jnp.inf, 20, 10, "checkpointed")    ## rtol, atol, max_dt, max_steps, max_steps_rev, kind (typically used by adaptive time steppers)
fixed_point_args = (1., 1e-12, 20)    ## learning_rate, tol, max_iter
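
## For intuition: a minimal damped fixed-point iteration showing how such
## (learning_rate, tol, max_iter) arguments are typically used. This is an
## illustrative sketch only, not necessarily nodepint's exact update rule:
def fixed_point_sketch(f, x0, learning_rate=1., tol=1e-12, max_iter=20):
    x = x0
    for _ in range(max_iter):
        x_new = (1 - learning_rate) * x + learning_rate * f(x)   # damped Picard step
        if jnp.linalg.norm(x_new - x) < tol:   # stop once the iterates stall
            break
        x = x_new
    return x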

# loss = optax.softmax_cross_entropy
# loss = optax.softmax_cross_entropy_with_integer_labels
loss = optax.l2_loss

keys = get_new_keys(SEED, num=3)
neural_nets = (Encoder(key=keys[0]), Processor(key=keys[1]), Decoder(key=keys[2]))

## Training hyperparameters (the PinT scheme itself is selected in train_params below)

nb_epochs = 5000
batch_size = 4
total_steps = nb_epochs * (len(ds) // batch_size)

## Scale the learning rate by 0.75 at 50% and again at 75% of training
scheduler = optax.piecewise_constant_schedule(init_value=1e-3, boundaries_and_scales={int(total_steps*0.5): 0.75, int(total_steps*0.75): 0.75})
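
## Optional sanity check: optax schedules are plain callables mapping an
## update count to a learning rate, so the decay points can be inspected:
# for step in [0, int(total_steps*0.5) + 1, int(total_steps*0.75) + 1]:
#     print(f"Learning rate at step {step}: {scheduler(step):.2e}")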


key = get_new_keys(SEED)


#%% [markdown]
# ## Train the model

#%%

train_params = {"neural_nets": neural_nets,
                "data": ds,
                # "pint_scheme": fixed_point_finder,
                # "pint_scheme": direct_root_finder_aug,
                "pint_scheme": parareal,    ## see the schematic sketch after this dict
                "samp_scheme": identity_sampling,
                # "integrator": rk4_integrator,
                "integrator": dopri_integrator_diffrax,
                "integrator_args": integrator_args,
                "loss_fn": loss,
                "optim_scheme": optim_scheme,
                "nb_processors": 20-1,
                "scheduler": scheduler,
                "times": times,
                "fixed_point_args": fixed_point_args,
                "nb_epochs": nb_epochs,
                "batch_size": batch_size,
                "repeat_projection": 1,
                "nb_vectors": 5,
                "force_serial": False,
                "key": key}
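
## For reference, a schematic of the classic parareal update that the chosen
## pint_scheme is named after (illustrative only; nodepint's `parareal` may
## differ in signature and details). G is a cheap coarse solver and F the
## accurate fine solver, each propagating a state across one time slice:
def parareal_sketch(G, F, u0, nb_slices, nb_iters):
    U = [u0]
    for n in range(nb_slices):
        U.append(G(U[n]))   # initial serial coarse sweep
    for _ in range(nb_iters):
        F_old = [F(U[n]) for n in range(nb_slices)]   # fine solves: parallel in time
        G_old = [G(U[n]) for n in range(nb_slices)]   # coarse solves at old iterates
        for n in range(nb_slices):
            U[n + 1] = G(U[n]) + F_old[n] - G_old[n]   # serial predictor-corrector
    return U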

start_time = time.time()
cpu_start_time = time.process_time()

trained_networks, shooting_fn, loss_hts, errors_hts, nb_iters_hts = train_project_neural_ode(**train_params)

cpu_time = time.process_time() - cpu_start_time
wall_time = time.time() - start_time

# print("\nNumber of iterations until eventual PinT convergence:\n", np.asarray(nb_iters_hts))
# print("Errors during PinT iterations:\n", np.asarray(errors_hts))

time_in_hmsecs = seconds_to_hours(wall_time)
print("\nTotal training time: %d hours %d mins %d secs" % time_in_hmsecs)


#%% [markdown]
# ## Analyse the loss history

#%%

# ## Plot the loss histories, one curve per projection iteration
# labels = [str(i) for i in range(len(loss_hts))]
# epochs = range(len(loss_hts[0]))

# sbplot(epochs, jnp.stack(loss_hts, axis=-1), label=labels, x_label="epochs", y_scale="log", title="Loss histories")

## Loss history across all iterations
total_loss = np.concatenate(loss_hts, axis=0)
total_epochs = 1 + np.arange(len(total_loss))

ax = sbplot(total_epochs, total_loss, x_label="epochs", y_scale="log", title="Total loss history")

## Save the plot
# plt.savefig("loss_history_coarse_euler.png")


#%% [markdown]
# ## Compute metrics on a test dataset

#%%

## Load the test dataset
test_ds = load_lotka_volterra_dataset(root_dir="./data", split="test")

print("\nNumber of testing examples:", len(test_ds))

## The test metric is the mean squared error (lower is better)
def acc_fn(x, y):
    return jnp.mean((x - y)**2)

test_params = {"neural_nets": trained_networks,
               "data": test_ds,
               "pint_scheme": parareal,    ## If None, the fixed_point_ad_rule is used
               # "pint_scheme": direct_scheme,
               # "integrator": rk4_integrator,
               "integrator": dopri_integrator_diffrax,
               "integrator_args": integrator_args,
               "fixed_point_args": fixed_point_args,
               "acc_fn": acc_fn,
               "shooting_fn": shooting_fn,
               "nb_processors": 20-1,
               "times": times,
               "batch_size": 4}


start_time = time.time()

avg_acc = test_neural_ode(**test_params)

test_wall_time = time.time() - start_time
time_in_hms = seconds_to_hours(test_wall_time)

print(f"\nAverage test loss: {avg_acc:.8f}")
print("Test time: %d hours %d mins %d secs" % time_in_hms)


# %%