Skip to content
This repository has been archived by the owner on Sep 11, 2023. It is now read-only.

Commit

Permalink
writing first pre-prepared dataset! #57
Browse files Browse the repository at this point in the history
  • Loading branch information
JackKelly committed Jul 20, 2021
1 parent f33d93c commit 5bd1448
Show file tree
Hide file tree
Showing 7 changed files with 665 additions and 621 deletions.
607 changes: 0 additions & 607 deletions notebooks/nwp_rechunk_experiments.ipynb

This file was deleted.

638 changes: 638 additions & 0 deletions notebooks/pre_process_dataset.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion nowcasting_dataset/data_sources/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class DataSource:
"""
history_len: int
forecast_len: int
convert_to_numpy: bool

def __post_init__(self):
assert self.history_len >= 0
Expand Down Expand Up @@ -66,7 +67,8 @@ def get_batch(
zipped = zip(t0_datetimes, x_locations, y_locations)
for t0_datetime, x_location, y_location in zipped:
example = self.get_example(t0_datetime, x_location, y_location)
example = to_numpy(example)
if self.convert_to_numpy:
example = to_numpy(example)
examples.append(example)

return examples
Expand Down
3 changes: 2 additions & 1 deletion nowcasting_dataset/data_sources/nwp_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ def get_batch(
selected_data = self._post_process_example(selected_data, t0_dt)

example = self._put_data_into_example(selected_data)
example = to_numpy(example)
if self.convert_to_numpy:
example = to_numpy(example)
examples.append(example)
return examples

Expand Down
3 changes: 2 additions & 1 deletion nowcasting_dataset/data_sources/satellite_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def get_batch(
example = self.get_example(t0_datetime, x_location, y_location)
examples.append(example)

examples = [to_numpy(example) for example in examples]
if self.convert_to_numpy:
examples = [to_numpy(example) for example in examples]
self._cache = {}
return examples

Expand Down
26 changes: 17 additions & 9 deletions nowcasting_dataset/datamodule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Optional, Iterable, Dict
from typing import Union, Optional, Iterable, Dict, Callable
from pathlib import Path
import pandas as pd
from copy import deepcopy
Expand Down Expand Up @@ -28,19 +28,22 @@ class NowcastingDataModule(pl.LightningDataModule):
pv_power_filename: Optional[Union[str, Path]] = None
pv_metadata_filename: Optional[Union[str, Path]] = None
batch_size: int = 8
n_training_batches_per_epoch: int = 2048
history_len: int = 2 #: Number of timesteps of history, not including t0.
forecast_len: int = 12 #: Number of timesteps of forecast, not including t0.
sat_filename: Union[str, Path] = consts.SAT_FILENAME
sat_channels: Iterable[str] = ('HRV', )
nwp_base_path: Optional[str] = None
nwp_channels: Optional[Iterable[str]] = (
't', 'dswrf', 'prate', 'r', 'sde', 'si10', 'vis', 'lcc', 'mcc', 'hcc')
image_size_pixels: int = 128
meters_per_pixel: int = 2000
image_size_pixels: int = 128 #: Passed to Data Sources.
meters_per_pixel: int = 2000 #: Passed to Data Sources.
convert_to_numpy: bool = True #: Passed to Data Sources.
pin_memory: bool = True #: Passed to DataLoader.
num_workers: int = 16 #: Passed to DataLoader.
prefetch_factor: int = 64 #: Passed to DataLoader.
n_samples_per_timestep: int = 2 #: Passed to NowcastingDataset
collate_fn: Callable = torch.utils.data._utils.collate.default_collate #: Passed to NowcastingDataset

def __post_init__(self):
super().__init__()
Expand All @@ -62,7 +65,8 @@ def prepare_data(self) -> None:
history_len=self.history_len,
forecast_len=self.forecast_len,
channels=self.sat_channels,
n_timesteps_per_batch=n_timesteps_per_batch)
n_timesteps_per_batch=n_timesteps_per_batch,
convert_to_numpy=self.convert_to_numpy)

self.data_sources = [self.sat_data_source]

Expand All @@ -76,7 +80,8 @@ def prepare_data(self) -> None:
start_dt=sat_datetimes[0],
end_dt=sat_datetimes[-1],
history_len=self.history_len,
forecast_len=self.forecast_len)
forecast_len=self.forecast_len,
convert_to_numpy=self.convert_to_numpy)

self.data_sources = [self.pv_data_source, self.sat_data_source]

Expand All @@ -89,13 +94,15 @@ def prepare_data(self) -> None:
history_len=self.history_len,
forecast_len=self.forecast_len,
channels=self.nwp_channels,
n_timesteps_per_batch=n_timesteps_per_batch)
n_timesteps_per_batch=n_timesteps_per_batch,
convert_to_numpy=self.convert_to_numpy)

self.data_sources.append(self.nwp_data_source)

self.datetime_data_source = data_sources.DatetimeDataSource(
history_len=self.history_len,
forecast_len=self.forecast_len)
forecast_len=self.forecast_len,
convert_to_numpy=self.convert_to_numpy)
self.data_sources.append(self.datetime_data_source)

def setup(self, stage='fit'):
Expand Down Expand Up @@ -151,7 +158,7 @@ def setup(self, stage='fit'):
self.train_dataset = dataset.NowcastingDataset(
t0_datetimes=self.train_t0_datetimes,
data_sources=self.data_sources,
n_batches_per_epoch_per_worker=self._n_batches_per_epoch_per_worker(1024 * 2),
n_batches_per_epoch_per_worker=self._n_batches_per_epoch_per_worker(self.n_training_batches_per_epoch),
**self._common_dataset_params())
self.val_dataset = dataset.NowcastingDataset(
t0_datetimes=self.val_t0_datetimes,
Expand Down Expand Up @@ -214,7 +221,8 @@ def contiguous_dataloader(self) -> torch.utils.data.DataLoader:
def _common_dataset_params(self) -> Dict:
return dict(
batch_size=self.batch_size,
n_samples_per_timestep=self.n_samples_per_timestep)
n_samples_per_timestep=self.n_samples_per_timestep,
collate_fn=self.collate_fn)

def _common_dataloader_params(self) -> Dict:
return dict(
Expand Down
5 changes: 3 additions & 2 deletions nowcasting_dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pandas as pd
import numpy as np
from numbers import Number
from typing import List, Tuple, Iterable
from typing import List, Tuple, Iterable, Callable
import nowcasting_dataset
from nowcasting_dataset import data_sources
from dataclasses import dataclass
Expand All @@ -24,6 +24,7 @@ class NowcastingDataset(torch.utils.data.IterableDataset):
n_samples_per_timestep: int
data_sources: List[data_sources.DataSource]
t0_datetimes: pd.DatetimeIndex #: Valid t0 datetimes.
collate_fn: Callable = torch.utils.data._utils.collate.default_collate

def __post_init__(self):
super().__init__()
Expand Down Expand Up @@ -89,7 +90,7 @@ def _get_batch(self) -> torch.Tensor:
for i in range(self.batch_size):
examples[i].update(examples_from_source[i])

return torch.utils.data._utils.collate.default_collate(examples)
return self.collate_fn(examples)

def _get_t0_datetimes_for_batch(self) -> pd.DatetimeIndex:
# Pick random datetimes.
Expand Down

0 comments on commit 5bd1448

Please sign in to comment.