
Commit 476f83f

Update notebook
1 parent 11795ef commit 476f83f

File tree

1 file changed: +39 -8 lines changed

autoemulate/experimental/exploratory/the_well/the_well_models_and_metrics.ipynb

Lines changed: 39 additions & 8 deletions
@@ -14,7 +14,7 @@
 "from the_well.data import WellDataModule\n",
 "import matplotlib.pyplot as plt\n",
 "import logging\n",
-"from autoemulate.experimental.emulators.the_well import TheWellEmulator, TheWellFNO\n",
+"from autoemulate.experimental.emulators.the_well import TheWellFNO\n",
 "\n",
 "logger = logging.getLogger()\n",
 "\n",
@@ -66,7 +66,7 @@
 "outputs": [],
 "source": [
 "# Fit the model\n",
-"em._fit(None, None)"
+"# em._fit(None, None)"
 ]
 },
 {
@@ -89,7 +89,7 @@
 "source": [
 "# Example of predictions with the rollout model from a single batch\n",
 "batch = next(iter(datamodule.val_dataloader()))\n",
-"em.trainer.rollout_model(em.trainer.model, batch, em.trainer.formatter, train=False)[0].shape"
+"batch_pred = em.trainer.rollout_model(em.trainer.model, batch, em.trainer.formatter, train=False)[0]"
 ]
 },
 {
@@ -99,7 +99,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"batch = next(iter(datamodule.val_dataloader()))"
+"batch_pred.shape"
 ]
 },
 {
@@ -109,7 +109,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"batch[\"input_fields\"].shape, batch[\"output_fields\"].shape, "
+"batch = next(iter(datamodule.val_dataloader()))\n",
+"batch[\"input_fields\"].shape, batch[\"output_fields\"].shape\n",
+"# (torch.Size([4, 1, 128, 384, 4]), torch.Size([4, 1, 128, 384, 4]))"
 ]
 },
 {
@@ -119,7 +121,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"batch_rollout = next(iter(datamodule.rollout_val_dataloader()))"
+"batch_rollout = next(iter(datamodule.rollout_val_dataloader()))\n",
+"batch_rollout_pred = em.trainer.rollout_model(em.trainer.model, batch_rollout, em.trainer.formatter, train=False)[0]"
 ]
 },
 {
@@ -129,7 +132,11 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"batch_rollout[\"input_fields\"].shape, batch_rollout[\"output_fields\"].shape, \n"
+"# Output is the whole stored time series from the initial condition\n",
+"print(batch_rollout[\"input_fields\"].shape, batch_rollout[\"output_fields\"].shape)\n",
+"# (torch.Size([1, 1, 128, 384, 4]), torch.Size([1, 100, 128, 384, 4]))\n",
+"print(batch_rollout_pred.shape)\n",
+"# torch.Size([1, 10, 128, 384, 4]): max_rollout_steps (10) from the current input"
 ]
 },
 {
@@ -139,7 +146,31 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"em.trainer.max_rollout_steps"
+"# Iterating over the batches covers the whole validation set, since each batch\n",
+"# starts from an initial condition in the validation set\n",
+"for i, batch in enumerate(datamodule.rollout_val_dataloader()):\n",
+"    print(f\"Batch {i}\")\n",
+"    # Output is the whole stored time series from the initial condition\n",
+"    print(batch[\"input_fields\"].shape, batch[\"output_fields\"].shape)\n",
+"    # (torch.Size([1, 1, 128, 384, 4]), torch.Size([1, 100, 128, 384, 4]))\n",
+"    batch_pred = em.trainer.rollout_model(em.trainer.model, batch, em.trainer.formatter, train=False)[0]\n",
+"    print(batch_pred.shape)\n",
+"    # torch.Size([1, 10, 128, 384, 4]): max_rollout_steps (10) from the current input"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "11",
+"metadata": {},
+"outputs": [],
+"source": [
+"# Validation loop\n",
+"em.trainer.validation_loop(\n",
+"    datamodule.rollout_val_dataloader(),\n",
+"    valid_or_test=\"rollout_valid\",\n",
+"    full=True\n",
+")"
 ]
 }
 ],

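A minimal sketch of the rollout-evaluation flow that the new cells add, assuming `datamodule` is the WellDataModule and `em` the TheWellFNO emulator built in earlier notebook cells (not shown in this diff). Only the trainer calls that appear in the diff (em.trainer.rollout_model, em.trainer.validation_loop, datamodule.rollout_val_dataloader) are used; the evaluate_rollouts helper name and the illustrative shape values are assumptions, not part of the commit.

# Sketch only: `datamodule` (WellDataModule) and `em` (TheWellFNO) are assumed to
# come from earlier notebook cells; `evaluate_rollouts` is a hypothetical helper.
def evaluate_rollouts(em, datamodule, verbose=True):
    """Roll the emulator out from each initial condition in the validation set."""
    preds = []
    for i, batch in enumerate(datamodule.rollout_val_dataloader()):
        # input_fields holds one initial state, e.g. torch.Size([1, 1, 128, 384, 4]);
        # output_fields holds the full stored trajectory, e.g. torch.Size([1, 100, 128, 384, 4])
        pred = em.trainer.rollout_model(
            em.trainer.model, batch, em.trainer.formatter, train=False
        )[0]
        # pred covers only the first max_rollout_steps steps, e.g. torch.Size([1, 10, 128, 384, 4]),
        # so a like-for-like target would be batch["output_fields"][:, :pred.shape[1]]
        preds.append(pred)
        if verbose:
            print(f"Batch {i}: input {tuple(batch['input_fields'].shape)}, "
                  f"prediction {tuple(pred.shape)}")
    return preds

# Aggregate rollout metrics over the same dataloader, as in the last added cell:
# em.trainer.validation_loop(
#     datamodule.rollout_val_dataloader(),
#     valid_or_test="rollout_valid",
#     full=True,
# )

Wrapping the per-batch rollout in a helper like this keeps the notebook cells short and makes the predictions reusable for plotting or metric computation.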
0 commit comments
