|
14 | 14 | "from the_well.data import WellDataModule\n", |
15 | 15 | "import matplotlib.pyplot as plt\n", |
16 | 16 | "import logging\n", |
17 | | - "from autoemulate.experimental.emulators.the_well import TheWellEmulator, TheWellFNO\n", |
| 17 | + "from autoemulate.experimental.emulators.the_well import TheWellFNO\n", |
18 | 18 | "\n", |
19 | 19 | "logger = logging.getLogger()\n", |
20 | 20 | "\n", |
|
66 | 66 | "outputs": [], |
67 | 67 | "source": [ |
68 | 68 | "# Fit the model\n", |
69 | | - "em._fit(None, None)" |
| 69 | + "# em._fit(None, None)" |
70 | 70 | ] |
71 | 71 | }, |
72 | 72 | { |
|
89 | 89 | "source": [ |
90 | 90 | "# Example of predictions with the rollout model from a single batch\n", |
91 | 91 | "batch = next(iter(datamodule.val_dataloader()))\n", |
92 | | - "em.trainer.rollout_model(em.trainer.model, batch, em.trainer.formatter, train=False)[0].shape" |
| 92 | + "batch_pred = em.trainer.rollout_model(em.trainer.model, batch, em.trainer.formatter, train=False)[0]" |
93 | 93 | ] |
94 | 94 | }, |
95 | 95 | { |
|
99 | 99 | "metadata": {}, |
100 | 100 | "outputs": [], |
101 | 101 | "source": [ |
102 | | - "batch = next(iter(datamodule.val_dataloader()))" |
| 102 | + "batch_pred.shape" |
103 | 103 | ] |
104 | 104 | }, |
105 | 105 | { |
|
109 | 109 | "metadata": {}, |
110 | 110 | "outputs": [], |
111 | 111 | "source": [ |
112 | | - "batch[\"input_fields\"].shape, batch[\"output_fields\"].shape, " |
| 112 | + "batch = next(iter(datamodule.val_dataloader()))\n", |
| 113 | + "batch[\"input_fields\"].shape, batch[\"output_fields\"].shape\n", |
| 114 | + "# (torch.Size([4, 1, 128, 384, 4]), torch.Size([4, 1, 128, 384, 4]))" |
113 | 115 | ] |
114 | 116 | }, |
115 | 117 | { |
|
119 | 121 | "metadata": {}, |
120 | 122 | "outputs": [], |
121 | 123 | "source": [ |
122 | | - "batch_rollout = next(iter(datamodule.rollout_val_dataloader()))" |
| 124 | + "batch_rollout = next(iter(datamodule.rollout_val_dataloader()))\n", |
| 125 | + "batch_rollout_pred = em.trainer.rollout_model(em.trainer.model, batch_rollout, em.trainer.formatter, train=False)[0]" |
123 | 126 | ] |
124 | 127 | }, |
125 | 128 | { |
|
129 | 132 | "metadata": {}, |
130 | 133 | "outputs": [], |
131 | 134 | "source": [ |
132 | | - "batch_rollout[\"input_fields\"].shape, batch_rollout[\"output_fields\"].shape, \n" |
| 135 | + "# Output is the whole time series from the \n", |
| 136 | + "print(batch_rollout[\"input_fields\"].shape, batch_rollout[\"output_fields\"].shape)\n", |
| 137 | + "# (torch.Size([1, 1, 128, 384, 4]), torch.Size([1, 100, 128, 384, 4]))\n", |
| 138 | + "print(batch_rollout_pred.shape)\n", |
| 139 | + "# torch.Size([1, 10, 128, 384, 4]) : max_rollout_steps (10) from current input" |
133 | 140 | ] |
134 | 141 | }, |
135 | 142 | { |
|
139 | 146 | "metadata": {}, |
140 | 147 | "outputs": [], |
141 | 148 | "source": [ |
142 | | - "em.trainer.max_rollout_steps" |
| 149 | + "# Iteration over the batches covers the whole validation set since it is based from the\n", |
| 150 | + "# initial conditions in the validation set\n", |
| 151 | + "for i, batch in enumerate(datamodule.rollout_val_dataloader()):\n", |
| 152 | + " print(f\"Batch {i}\")\n", |
| 153 | + " # Output is the whole time series from the \n", |
| 154 | + " print(batch_rollout[\"input_fields\"].shape, batch_rollout[\"output_fields\"].shape)\n", |
| 155 | + " # (torch.Size([1, 1, 128, 384, 4]), torch.Size([1, 100, 128, 384, 4]))\n", |
| 156 | + " batch_rollout_pred = em.trainer.rollout_model(em.trainer.model, batch_rollout, em.trainer.formatter, train=False)[0]\n", |
| 157 | + " print(batch_rollout_pred.shape)\n", |
| 158 | + " # torch.Size([1, 10, 128, 384, 4]) : max_rollout_steps (10) from current input " |
| 159 | + ] |
| 160 | + }, |
| 161 | + { |
| 162 | + "cell_type": "code", |
| 163 | + "execution_count": null, |
| 164 | + "id": "11", |
| 165 | + "metadata": {}, |
| 166 | + "outputs": [], |
| 167 | + "source": [ |
| 168 | + "# Validation loop\n", |
| 169 | + "em.trainer.validation_loop(\n", |
| 170 | + " datamodule.rollout_val_dataloader(),\n", |
| 171 | + " valid_or_test=\"rollout_valid\",\n", |
| 172 | + " full=True\n", |
| 173 | + ")" |
143 | 174 | ] |
144 | 175 | } |
145 | 176 | ], |
|
0 commit comments