
Add "datastores" to represent input data from zarr, npy, etc #66

Merged
merged 358 commits into from
Nov 21, 2024
Changes from 1 commit
Commits
358 commits
c52f98e
npy mllam nearly done
leifdenby Jul 6, 2024
80f3639
minor adjustment
leifdenby Jul 7, 2024
048f8c6
Merge branch 'main' of https://github.com/mllam/neural-lam into maint…
leifdenby Jul 11, 2024
5aaa239
add pooch and tweak pip cicd testing
leifdenby Jul 11, 2024
66c3b03
combine cicd tests with caching
leifdenby Jul 11, 2024
8566b8f
linting
leifdenby Jul 11, 2024
29bd9e5
add pyg dep
leifdenby Jul 11, 2024
bc7f028
set cirun aws region to frankfurt
leifdenby Jul 11, 2024
2070166
adapt image
leifdenby Jul 11, 2024
e4e86e5
set image
leifdenby Jul 11, 2024
1fba8fe
try different image
leifdenby Jul 11, 2024
02b77cf
add pooch to cicd
leifdenby Jul 11, 2024
b481929
add pdm gpu test
leifdenby Jul 16, 2024
bcec472
start work on readme
leifdenby Jul 16, 2024
c5beec9
Merge branch 'maint/deps-in-pyproject-toml' into datastore
leifdenby Jul 16, 2024
e89facc
Merge branch 'main' into maint/refactor-as-package
leifdenby Jul 16, 2024
0b5687a
Merge branch 'main' of https://github.com/mllam/neural-lam into maint…
leifdenby Jul 16, 2024
095fdbc
turn meps testdata download into pytest fixture
leifdenby Jul 16, 2024
49e9bfe
adapt README for package
leifdenby Jul 16, 2024
12cc02b
remove pdm cicd test (will be in separate PR)
leifdenby Jul 16, 2024
b47f50b
remove pdm in gitignore
leifdenby Jul 16, 2024
90d99ca
remove pdm and pyproject files (will be sep PR)
leifdenby Jul 16, 2024
a91eaaa
add pyproject.toml from main
leifdenby Jul 16, 2024
5508cea
clean out tests
leifdenby Jul 16, 2024
5c623c3
fix linting
leifdenby Jul 16, 2024
08ec168
add cli entrypoints import test
leifdenby Jul 16, 2024
d9cf7ba
Merge branch 'maint/refactor-as-package' into datastore
leifdenby Jul 16, 2024
3954f04
tweak cicd pytest execution
leifdenby Jul 16, 2024
f99fdce
Merge branch 'maint/refactor-as-package' into datastore
leifdenby Jul 16, 2024
db9d96f
Update tests/test_mllam_dataset.py
leifdenby Jul 17, 2024
3c864b2
grid-shape ok
leifdenby Jul 17, 2024
1f54b0e
get_vars_names and units
leifdenby Jul 17, 2024
9b88160
get_vars_names and units 2
leifdenby Jul 17, 2024
a9fdad5
test for stats
leifdenby Jul 23, 2024
555154f
get_dataarray test
leifdenby Jul 24, 2024
8b8a77e
get_dataarray test
leifdenby Jul 24, 2024
41f11cd
boundary_mask
leifdenby Jul 24, 2024
a17de0f
get_xy
leifdenby Jul 24, 2024
0a38a7d
remove TrainingSample dataclass
leifdenby Jul 24, 2024
f65f6b5
test for WeatherDataset.__getitem__
leifdenby Jul 24, 2024
a35100e
test for graph creation
leifdenby Jul 24, 2024
cfb0618
more graph creation tests
leifdenby Jul 24, 2024
8698719
check for consistency of num features across splits
leifdenby Jul 24, 2024
3381404
test for single batch from mllam through model
leifdenby Jul 24, 2024
2a6796c
Add init files to expose classes in editable package
joeloskarsson Jul 24, 2024
8f4e0e0
Linting
joeloskarsson Jul 24, 2024
e657abb
working training_step with datastores!
Jul 25, 2024
effc99b
remove superfluous tests
Jul 25, 2024
a047026
fix for dataset length
Jul 25, 2024
d2c62ed
step length should be int
Jul 25, 2024
58f5d99
step length should be int
Jul 25, 2024
64d43a6
training working with mllam datastore!
Jul 25, 2024
07444f8
adapt neural_lam.train_model for datastores
Jul 25, 2024
d1b6fc1
fixes for npy
Jul 25, 2024
6fe19ac
npyfiles datastore complete
leifdenby Jul 26, 2024
fe65a4d
cleanup for datastore examples
leifdenby Jul 26, 2024
e533794
training on ohm with danra!
Jul 26, 2024
640ac05
use mllam-data-prep v0.2.0
Aug 5, 2024
0f16f13
remove py3.12 from pre-commit
Aug 5, 2024
724548e
cleanup
Aug 8, 2024
a1b2037
all tests passing!
Aug 12, 2024
e35958f
use mllam-data-prep v0.3.0
Aug 12, 2024
8b92318
delete requirements.txt
Aug 13, 2024
658836a
remove .DS_Store
Aug 13, 2024
421efed
use tmate in gpu pdm cicd
Aug 13, 2024
05f1e9f
remove requirements
Aug 13, 2024
3afe0e4
update pdm gpu cicd setup to pdm venv on nvme drive
Aug 13, 2024
f3d028b
don't try to use pdm venv in-project
Aug 13, 2024
2c35662
remove tmate
Aug 13, 2024
5f30255
update README with install instructions
Aug 14, 2024
b2b5631
changelog
Aug 14, 2024
c8ae829
update ci/cd badges to include gpu + gpu
Aug 14, 2024
e7cf2c0
Merge pull request #1 from mllam/package_inits
leifdenby Aug 14, 2024
0b72e9d
add pyproject-flake8 to precommit config
Aug 14, 2024
190d1de
use Flake8-pyproject instead
Aug 14, 2024
791af0a
update README
Aug 14, 2024
58fab84
Merge branch 'maint/deps-in-pyproject-toml' into feat/datastores
Aug 14, 2024
dbe2e6d
Merge branch 'maint/refactor-as-package' into maint/deps-in-pyproject…
Aug 14, 2024
eac6e35
Merge branch 'maint/deps-in-pyproject-toml' into feat/datastores
Aug 14, 2024
799d55e
linting fixes
Aug 14, 2024
57bbb81
train only 1 epoch in cicd and print to stdout
Aug 14, 2024
a955cee
log datastore config
Aug 14, 2024
0a79c74
cleanup doctrings
Aug 15, 2024
9f3c014
Merge branch 'maint/refactor-as-package' into datastore
leifdenby Aug 19, 2024
41364a8
Merge branch 'main' of https://github.com/mllam/neural-lam into maint…
leifdenby Aug 19, 2024
3422298
update changelog
leifdenby Aug 19, 2024
689ef69
move dev deps optional dependencies group
leifdenby Aug 20, 2024
9a0d538
update cicd tests to install dev deps
leifdenby Aug 20, 2024
bddfcaf
update readme with new dev deps group
leifdenby Aug 20, 2024
b96cfdc
quote the skip step the install readme
leifdenby Aug 20, 2024
2600dee
remove unused files
leifdenby Aug 20, 2024
65a8074
Merge branch 'feat/datastores' of https://github.com/leifdenby/neural…
leifdenby Aug 20, 2024
6adf6cc
revert to line length of 80
leifdenby Aug 20, 2024
46b37f8
revert docstring formatting changes
leifdenby Aug 20, 2024
3cd0f8b
pin numpy to <2.0.0
leifdenby Aug 20, 2024
826270a
Merge branch 'maint/deps-in-pyproject-toml' into feat/datastores
leifdenby Aug 20, 2024
4ba22ea
Merge branch 'main' into feat/datastores
leifdenby Aug 20, 2024
1f661c6
fix flake8 linting errors
leifdenby Aug 20, 2024
4838872
Update neural_lam/weather_dataset.py
leifdenby Sep 8, 2024
b59e7e5
Update neural_lam/datastore/multizarr/create_normalization_stats.py
leifdenby Sep 8, 2024
75b1fe7
Update neural_lam/datastore/npyfiles/store.py
leifdenby Sep 8, 2024
7e736cb
Update neural_lam/datastore/npyfiles/store.py
leifdenby Sep 8, 2024
613a7e2
Update neural_lam/datastore/npyfiles/store.py
leifdenby Sep 8, 2024
65e199b
Update tests/test_training.py
leifdenby Sep 8, 2024
4435e26
Update tests/test_datasets.py
leifdenby Sep 8, 2024
4693408
Update README.md
leifdenby Sep 8, 2024
2dfed2c
update README
leifdenby Sep 10, 2024
c3d033d
Merge branch 'main' of https://github.com/mllam/neural-lam into feat/…
leifdenby Sep 10, 2024
4a70268
Merge branch 'feat/datastores' of https://github.com/leifdenby/neural…
leifdenby Sep 10, 2024
66c663f
column_water -> open_water_fraction
leifdenby Sep 10, 2024
11a7978
fix linting
leifdenby Sep 10, 2024
a41c314
static data same for all splits
leifdenby Sep 10, 2024
6f1efd6
forcing_window_size from args
leifdenby Sep 10, 2024
bacb9ec
Update neural_lam/datastore/base.py
leifdenby Sep 10, 2024
4a9db4e
only use first ensemble member in datastores
leifdenby Sep 10, 2024
4fc2448
Merge branch 'feat/datastores' of https://github.com/leifdenby/neural…
leifdenby Sep 10, 2024
bcaa919
Update neural_lam/datastore/base.py
leifdenby Sep 10, 2024
90bc594
Update neural_lam/datastore/base.py
leifdenby Sep 10, 2024
5bda935
Update neural_lam/datastore/base.py
leifdenby Sep 10, 2024
8e7931d
remove all multizarr functionality
leifdenby Sep 10, 2024
6998683
cleanup and test fixes for recent changes
leifdenby Sep 10, 2024
c415008
Merge branch 'feat/datastores' of https://github.com/leifdenby/neural…
leifdenby Sep 10, 2024
735d324
fix linting
leifdenby Sep 10, 2024
5f2d919
remove multizar example files
leifdenby Sep 10, 2024
5263d2c
normalization -> standardization
leifdenby Sep 10, 2024
ba1bec3
fix import for tests
leifdenby Sep 10, 2024
d04d15e
Update neural_lam/datastore/base.py
leifdenby Sep 10, 2024
743d7a1
fix coord issues and add datastore example plotting cli
leifdenby Sep 12, 2024
ac10d7d
add lru_cache to get_xy_extent
leifdenby Sep 12, 2024
bf8172a
MLLAMDatastore -> MDPDatastore
leifdenby Sep 12, 2024
90ca400
missed renames for MDPDatastore
leifdenby Sep 12, 2024
154139d
update graph plot for datastores
leifdenby Sep 12, 2024
50ee0b0
use relative import
leifdenby Sep 12, 2024
7dfd570
add long_names and refactor npyfiles create weights
leifdenby Sep 12, 2024
2b45b5a
Update neural_lam/weather_dataset.py
leifdenby Sep 23, 2024
aee0b1c
Update neural_lam/weather_dataset.py
leifdenby Sep 23, 2024
8453c2b
Update neural_lam/models/ar_model.py
leifdenby Sep 27, 2024
7f32557
Update neural_lam/weather_dataset.py
leifdenby Sep 27, 2024
67998b8
read projection from datastore config extra section
leifdenby Sep 27, 2024
ac7e46a
NpyFilesDatastore -> NpyFilesDatastoreMEPS
leifdenby Sep 27, 2024
b7bf506
revert tp training with 1 AR step by default
leifdenby Sep 27, 2024
5df2ecf
add missing kwarg to BaseHiGraphModel.__init__
leifdenby Sep 27, 2024
d4d438f
add missing kwarg to HiLAM.__init__
leifdenby Sep 27, 2024
1889771
add missing kwarg to HiLAMParallel
leifdenby Sep 27, 2024
2c3bbde
check that for enough forecast steps given ar_steps
leifdenby Sep 27, 2024
f0a151b
remove numpy<2.0.0 version cap
leifdenby Sep 27, 2024
f3566b0
tweak print statement working in mdp
Oct 1, 2024
dba94b3
fix missed removed argument from cli
Oct 1, 2024
bca1482
remove wandb config log comment, we log now
Oct 1, 2024
fc973c4
ensure loading from checkpoint during train possible
Oct 1, 2024
9fcf06e
get step_length from datastore in plot_error_map
leifdenby Oct 1, 2024
2bbe666
remove step_legnth attr in ARModel
leifdenby Oct 1, 2024
b41ed2f
remove unused obs_mask arg for vis.plot_prediction
leifdenby Oct 1, 2024
7e46194
ensure no reference to multizarr "data_config"
leifdenby Oct 1, 2024
b57bc7a
introduce neural-lam config
leifdenby Oct 2, 2024
2b30715
include meps neural-lam config example
leifdenby Oct 2, 2024
8e7b2e6
fix extra space typo in BaseDatastore
leifdenby Oct 2, 2024
e0300fb
add check and print of train/test/val split in MDPDatastore
leifdenby Oct 2, 2024
d1b4ca7
BaseCartesianDatastore -> BaseRegularGridDatastore
leifdenby Oct 3, 2024
de46fb4
removed `control_only' arg
sadamov Oct 23, 2024
c1a7159
All flags are explicit
sadamov Oct 23, 2024
5b02761
removed multizarr, obsolete
sadamov Oct 23, 2024
f80fe4a
robust import of conftest
sadamov Oct 23, 2024
0222759
fixed torch List typing
sadamov Oct 23, 2024
3d91f7c
graph creation is handled by WMG
sadamov Oct 23, 2024
65cb4a8
graph creation now handled in WGM
sadamov Oct 23, 2024
b1e2097
Making sure that all tensors, arrays and datesets follow the same ord…
sadamov Oct 24, 2024
96900c1
expanding the dummy class to support all tests
sadamov Oct 24, 2024
4281a12
clarify comment about array shape
sadamov Oct 24, 2024
84ea4d3
Add caching decorator
sadamov Oct 24, 2024
930a13d
by default dataset is written to_zarr
sadamov Oct 24, 2024
2b9f00d
prevent removal of old zarr-archives
sadamov Oct 24, 2024
d9e4822
Removed dev-dependencies
sadamov Oct 24, 2024
4bed96e
rename sampling to slicing
sadamov Oct 24, 2024
1f58798
Align datastore_config_path arguments
sadamov Oct 25, 2024
53f32aa
imlement flexible window slices for forcings (past & future)
sadamov Oct 25, 2024
6a20a9d
Expanded docstring for stacked get_xy
sadamov Oct 25, 2024
5c9b4c5
Implementation of feature weights
sadamov Oct 25, 2024
0a41d0c
Removing some obsolete "better"-comments
sadamov Oct 25, 2024
239aad7
Bugfixes and better documentation of time slicing operations
sadamov Oct 27, 2024
98dde82
reintroduction of create_graphp
sadamov Oct 28, 2024
01e6dff
implementation of state_feature_weights
sadamov Oct 28, 2024
0878fce
bugfix for length of forcing window
sadamov Oct 28, 2024
1b9d253
formatting
sadamov Oct 28, 2024
85aa170
update README.md to reflect renaming of create_mesh to create_graph
khintz Oct 29, 2024
b9c9951
update instructions on creating graph
khintz Oct 29, 2024
9e586d8
Replace component_dependencies figure with mermaid diagram
khintz Oct 29, 2024
f6c6404
add index selection to datastore example plot cli
leifdenby Nov 4, 2024
8421a6a
more work on readme
leifdenby Nov 5, 2024
3c045c7
Merge pull request #2 from sadamov/feat/datastores
leifdenby Nov 6, 2024
8deace8
Merge branch 'feat/datastores' of https://github.com/leifdenby/neural…
leifdenby Nov 6, 2024
8149d65
Use "datastore" in config filename for datastores
leifdenby Nov 6, 2024
514b0d1
add loss-weighting config and implementations
leifdenby Nov 8, 2024
e642cb0
Bugfix for forcing window calculation
sadamov Nov 8, 2024
599917d
Fix shift by init_steps
sadamov Nov 8, 2024
ef3da41
Cover cases where include_past_forcing > init_steps
sadamov Nov 8, 2024
74828fa
Merge main branch into datastores and resolve README conflicts
joeloskarsson Nov 11, 2024
8cc6c3d
Add test for dataset length using different configs
joeloskarsson Nov 11, 2024
731910f
Updates for datastore examples and neural-lam config
leifdenby Nov 12, 2024
ff02af7
linting fix
leifdenby Nov 12, 2024
b5844c0
only prepare npymeps test example files that don't exist
leifdenby Nov 12, 2024
b7a10ef
ensure dimension order from BaseRegularGridDatastore.stack_grid_coords
leifdenby Nov 12, 2024
31ebfc8
get datastore static data in ARModel without defining split
leifdenby Nov 12, 2024
f022365
Update neural_lam/train_model.py
leifdenby Nov 12, 2024
b33e863
Update neural_lam/weather_dataset.py
leifdenby Nov 12, 2024
a8362ce
suggest to reduce ar_steps and forcing window with small dataset
leifdenby Nov 12, 2024
b2e0874
adapt dummy datastore to generate on equal area grid
leifdenby Nov 12, 2024
772cc20
adapt all cli to use --config arg instead of `config`
leifdenby Nov 12, 2024
89b10b5
add test for datastores example plot function
leifdenby Nov 12, 2024
d355ef5
bugfix for earlier unstacking dim order fix in datastores
leifdenby Nov 12, 2024
1121d9f
add enforcement of datastores output dimension order
leifdenby Nov 13, 2024
9afaf6e
fix bugs introduced with dimension order during stack/unstack
leifdenby Nov 13, 2024
3df627f
update meps test to point to new dataset on aws
leifdenby Nov 13, 2024
89fac82
remove unused print statement
leifdenby Nov 13, 2024
a95eb5a
fix config-path arg bug in CLIs
leifdenby Nov 13, 2024
b56e47a
renaming the forcing arguments
sadamov Nov 13, 2024
258079c
Merge branch 'feat/datastores' of github.com:leifdenby/neural-lam int…
sadamov Nov 13, 2024
d458677
fix bug for datastore ref in ARModel.plot_examples()
leifdenby Nov 13, 2024
223db37
improved docstring for forcings
sadamov Nov 13, 2024
bcc3e51
Adjusting index of flux based on datastore
sadamov Nov 13, 2024
f97719b
defined forcings to be 0, meaningless for stats_calc in MEPS
sadamov Nov 13, 2024
df4d39c
fix typo (missing datastore) in ARModel.on_test_epoch_end
leifdenby Nov 13, 2024
46f161c
setting ar_steps to 63 for stats calc in MEPS
sadamov Nov 13, 2024
38cdfe6
Merge branch 'feat/datastores' of https://github.com/leifdenby/neural…
sadamov Nov 13, 2024
f36d1ce
more verbose ci/cd testing and update meps cache
leifdenby Nov 13, 2024
cd53b21
Bugfix, `idx` removed from forecast forcing window indices
sadamov Nov 13, 2024
98706c1
Merge branch 'feat/datastores' of https://github.com/leifdenby/neural…
sadamov Nov 13, 2024
7c62778
absolute imports
sadamov Nov 14, 2024
8ddd0de
change datastore arg type
sadamov Nov 14, 2024
0f24924
calling cli() instead of main();
sadamov Nov 14, 2024
acb8ffa
add test for state/forcing values from time-slicing
leifdenby Nov 14, 2024
7fe1726
Update tests/test_time_slicing.py
leifdenby Nov 14, 2024
4fe2cea
Fix typo in time slicing test
joeloskarsson Nov 14, 2024
dfec1ec
formatting
sadamov Nov 16, 2024
f56a999
enable worker argument, but set to zero for tests
sadamov Nov 16, 2024
4a5ae6c
reduce workers to zero
sadamov Nov 16, 2024
9f0120b
revert num_workers to 1 in test
sadamov Nov 16, 2024
665368d
Fix missing datastore kind in plot script
joeloskarsson Nov 18, 2024
a90a979
replace transpose in WeatherDataset.__getitem__ with assert
leifdenby Nov 18, 2024
6fedea5
Merge torch.load change from main into datastores
joeloskarsson Nov 18, 2024
0180ca0
Merge branch 'main' into datastores
joeloskarsson Nov 18, 2024
93c20fc
default config path should be None for datastore plote example
leifdenby Nov 18, 2024
f6da2b2
return stacked coords by default from BaseRegularGridDatastore.get_xy()
leifdenby Nov 18, 2024
fc6be8d
Fix typos and clarifications in readme
joeloskarsson Nov 19, 2024
9787869
Fix dim ordering in time slicing test
joeloskarsson Nov 19, 2024
4cb44de
Reduce example size of single batch and training tests to save memory
joeloskarsson Nov 19, 2024
daf1dbc
Add changelog entry
joeloskarsson Nov 19, 2024
f922542
use mllam-data-prep v0.5.0
leifdenby Nov 20, 2024
9e8c08f
add support for setting globe properties in projection
leifdenby Nov 20, 2024
4302d58
update path for meps data chache in ci/cd
leifdenby Nov 20, 2024
remove TrainingSample dataclass
leifdenby committed Jul 24, 2024
commit 0a38a7d453d0a2cdb73f38d19b7b6af8adf32b34
8 changes: 4 additions & 4 deletions neural_lam/create_graph.py
@@ -475,14 +475,14 @@ def create_graph(

def create_graph_from_datastore(
datastore: BaseCartesianDatastore,
graph_dir_path: str,
output_root_path: str,
n_max_levels: int = None,
hierarchical: bool = False,
create_plot: bool = False,
):
xy = datastore.get_xy(category="state", stacked=False)
create_graph(
graph_dir_path=graph_dir_path,
graph_dir_path=output_root_path,
xy=xy,
n_max_levels=n_max_levels,
hierarchical=hierarchical,
@@ -505,7 +505,7 @@ def cli(input_args=None):
help="path to the data store",
)
parser.add_argument(
"--graph",
"--name",
type=str,
default="multiscale",
help="Name to save graph as (default: multiscale)",
@@ -536,7 +536,7 @@

create_graph_from_datastore(
datastore=datastore,
graph_dir_path=os.path.join("graphs", args.graph),
output_root_path=os.path.join(datastore.root_path, "graphs", args.name),
n_max_levels=args.levels,
hierarchical=args.hierarchical,
create_plot=args.plot,
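The hunks above rename the CLI flag from `--graph` to `--name` and root the graph output under the datastore rather than a hard-coded top-level `graphs/` folder. A minimal sketch of the new argument and path construction (the root path below is an illustrative stand-in for `datastore.root_path`, not the real object):

```python
import argparse
from pathlib import Path

# Parse the renamed flag exactly as in the diff: --name, defaulting to
# "multiscale".
parser = argparse.ArgumentParser()
parser.add_argument(
    "--name",
    type=str,
    default="multiscale",
    help="Name to save graph as (default: multiscale)",
)
args = parser.parse_args(["--name", "hierarchical"])

# The output directory is now derived from the datastore's root path
# instead of a hard-coded "graphs/" folder in the working directory.
root_path = Path("/data/danra")  # stand-in for datastore.root_path
output_root_path = root_path / "graphs" / args.name
print(output_root_path)  # /data/danra/graphs/hierarchical
```

This keeps all derived artifacts (graphs, statistics) colocated with the data they were built from.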
14 changes: 14 additions & 0 deletions neural_lam/datastore/base.py
@@ -1,6 +1,7 @@
# Standard library
import abc
import dataclasses
from pathlib import Path
from typing import List, Union

# Third-party
@@ -30,6 +31,19 @@ class BaseDatastore(abc.ABC):
is_ensemble: bool = False
is_forecast: bool = False

@property
@abc.abstractmethod
def root_path(self) -> Path:
"""The root path to the datastore. It is relative to this that any
derived files (for example the graph components) are stored.

Returns
-------
pathlib.Path
The root path to the datastore.
"""
pass

@property
@abc.abstractmethod
def step_length(self) -> int:
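The hunk above adds `root_path` as an abstract property on the datastore base class, so every concrete datastore must say where its derived files live. The same pattern can be sketched in isolation like this (class names here are illustrative stand-ins, not the actual neural-lam classes):

```python
import abc
from pathlib import Path


class BaseStore(abc.ABC):
    """Minimal stand-in for a datastore base class."""

    @property
    @abc.abstractmethod
    def root_path(self) -> Path:
        """Root path relative to which derived files are stored."""
        ...


class ConfigFileStore(BaseStore):
    """Concrete store configured by a file; its root is the config's parent."""

    def __init__(self, config_path: str):
        self._config_path = Path(config_path)

    @property
    def root_path(self) -> Path:
        # Mirrors the MDP datastore below: root is the config file's directory.
        return self._config_path.parent


store = ConfigFileStore("/data/danra/config.yaml")
print(store.root_path)  # /data/danra
```

Because `root_path` is both a `@property` and `@abc.abstractmethod`, Python refuses to instantiate any subclass that forgets to implement it.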
4 changes: 4 additions & 0 deletions neural_lam/datastore/mllam.py
@@ -46,6 +46,10 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
self._ds.to_zarr(fp_ds)
self._n_boundary_points = n_boundary_points

@property
def root_path(self) -> Path:
return Path(self._config_path.parent)

def step_length(self) -> int:
da_dt = self._ds["time"].diff("time")
return da_dt.dt.seconds[0] // 3600
10 changes: 9 additions & 1 deletion neural_lam/models/base_graph_model.py
@@ -19,7 +19,10 @@ def __init__(self, args, datastore, forcing_window_size):
# Load graph with static features
# NOTE: (IMPORTANT!) mesh nodes MUST have the first
# num_mesh_nodes indices,
self.hierarchical, graph_ldict = utils.load_graph(args.graph)
graph_dir_path = datastore.root_path / "graph" / args.graph
self.hierarchical, graph_ldict = utils.load_graph(
graph_dir_path=graph_dir_path
)
for name, attr_value in graph_ldict.items():
# Make BufferLists module members and register tensors as buffers
if isinstance(attr_value, torch.Tensor):
@@ -102,6 +105,11 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
"""
batch_size = prev_state.shape[0]

print(f"prev_state.shape: {prev_state.shape}")
print(f"prev_prev_state.shape: {prev_prev_state.shape}")
print(f"forcing.shape: {forcing.shape}")
print(f"grid_static_features.shape: {self.grid_static_features.shape}")

# Create full grid node features of shape (B, num_grid_nodes, grid_dim)
grid_features = torch.cat(
(
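With `root_path` available on every datastore, the model can derive its graph directory per datastore instead of from a global `graphs/` folder. A sketch of the path composition used in the hunk above (`DummyStore` is a hypothetical stand-in for a real datastore):

```python
from pathlib import Path


class DummyStore:
    """Hypothetical datastore exposing only root_path."""

    root_path = Path("/data/danra")


def graph_dir_for(datastore, graph_name: str) -> Path:
    # Mirrors the diff: datastore.root_path / "graph" / args.graph
    return datastore.root_path / "graph" / graph_name


print(graph_dir_for(DummyStore(), "multiscale"))  # /data/danra/graph/multiscale
```

The resulting directory is then what gets passed to `utils.load_graph`, which now takes an explicit path rather than a graph name.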
4 changes: 1 addition & 3 deletions neural_lam/utils.py
@@ -32,10 +32,8 @@ def __iter__(self):
return (self[i] for i in range(len(self)))


def load_graph(graph_name, device="cpu"):
def load_graph(graph_dir_path, device="cpu"):
"""Load all tensors representing the graph."""
# Define helper lambda function
graph_dir_path = os.path.join("graphs", graph_name)

def loads_file(fn):
return torch.load(os.path.join(graph_dir_path, fn), map_location=device)
78 changes: 2 additions & 76 deletions neural_lam/weather_dataset.py
@@ -1,5 +1,4 @@
# Standard library
import dataclasses
import warnings

# Third-party
@@ -12,74 +11,6 @@
from neural_lam.datastore.base import BaseDatastore


@dataclasses.dataclass
class TrainingSample:
"""A dataclass to hold a single training sample of `ar_steps`
autoregressive steps, which consists of the initial states, target states,
forcing and batch times. The initial and target states should have
`d_features` features, and the forcing should have `d_windowed_forcing`
features.

Parameters
----------
init_states : torch.Tensor
The initial states of the training sample,
shape (2, N_grid, d_features).
target_states : torch.Tensor
The target states of the training sample,
shape (ar_steps, N_grid, d_features).
forcing : torch.Tensor
The forcing of the training sample,
shape (ar_steps, N_grid, d_windowed_forcing).
batch_times : np.ndarray
The times of the batch, shape (ar_steps,).
"""

init_states: torch.Tensor
target_states: torch.Tensor
forcing: torch.Tensor
batch_times: np.ndarray

def __post_init__(self):
"""Validate the shapes of the tensors match between the different
components of the training sample.

init_states: (2, N_grid, d_features)
target_states: (ar_steps, N_grid, d_features)
forcing: (ar_steps, N_grid, d_windowed_forcing) # batch_times: (ar_steps,)
"""
assert self.init_states.shape[0] == 2
_, N_grid, d_features = self.init_states.shape
N_pred_steps = self.target_states.shape[0]

# check number of grid points
if not (
self.target_states.shape[1] == self.target_states.shape[1] == N_grid
):
raise Exception(
"Number of grid points do not match, got "
f"{self.target_states.shape[1]=} and "
f"{self.target_states.shape[2]=}, expected {N_grid=}"
)

# check number of features for init and target states
assert self.target_states.shape[2] == d_features

# check that target, forcing and batch times have the same number of
# prediction steps
if not (
self.target_states.shape[0]
== self.forcing.shape[0]
== self.batch_times.shape[0]
== N_pred_steps
):
raise Exception(
"Number of prediction steps do not match, got "
f"{self.target_states.shape[0]=}, {self.forcing.shape[0]=} and "
f"{self.batch_times.shape[0]=}, expected {N_pred_steps=}"
)


class WeatherDataset(torch.utils.data.Dataset):
"""Dataset class for weather data.

@@ -268,7 +199,7 @@ def __getitem__(self, idx):
da_init_states = da_state.isel(time=slice(None, 2))
da_target_states = da_state.isel(time=slice(2, None))

batch_times = da_forcing_windowed.time
batch_times = da_forcing_windowed.time.values.astype(float)

if self.standardize:
da_init_states = (
@@ -300,12 +231,7 @@
# forcing: (ar_steps, N_grid, d_windowed_forcing)
# batch_times: (ar_steps,)

return TrainingSample(
init_states=init_states,
target_states=target_states,
forcing=forcing,
batch_times=batch_times,
)
return init_states, target_states, forcing, batch_times


class WeatherDataModule(pl.LightningDataModule):
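Returning a plain tuple from `__getitem__`, as this commit does, lets PyTorch's default collate function batch samples without any custom collate logic, which is one reason to drop the dataclass. A self-contained sketch of the pattern, with plain lists standing in for the tensors:

```python
class TinyDataset:
    """Sketch of a Dataset whose __getitem__ returns a plain tuple.

    Plain tuples are handled natively by torch.utils.data.default_collate,
    whereas a custom dataclass would need a custom collate_fn.
    """

    def __init__(self, n_samples: int):
        self._n = n_samples

    def __len__(self):
        return self._n

    def __getitem__(self, idx):
        init_states = [idx, idx]    # stands in for shape (2, N_grid, d_features)
        target_states = [idx + 1]   # stands in for shape (ar_steps, ...)
        forcing = [0.0]             # stands in for windowed forcing
        batch_times = [float(idx)]  # times as floats, as in the diff above
        return init_states, target_states, forcing, batch_times


ds = TinyDataset(3)
init, target, forcing, times = ds[1]  # tuple unpacks directly at the call site
```

The shape validation that `TrainingSample.__post_init__` performed is instead covered by the dedicated dataset tests added elsewhere in this PR.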
51 changes: 32 additions & 19 deletions tests/test_datastores.py
@@ -18,11 +18,14 @@
In addition BaseCartesianDatastore must have the following methods and attributes:
- [x] `get_xy_extent` (method): Return the extent of the x, y coordinates for a
given category of data.
- [ ] `get_xy` (method): Return the x, y coordinates of the dataset.
- [ ] `coords_projection` (property): Projection object for the coordinates.
- [ ] `grid_shape_state` (property): Shape of the grid for the state variables.
- [x] `get_xy` (method): Return the x, y coordinates of the dataset.
- [x] `coords_projection` (property): Projection object for the coordinates.
- [x] `grid_shape_state` (property): Shape of the grid for the state variables.
"""

# Standard library
from pathlib import Path

# Third-party
import cartopy.crs as ccrs
import numpy as np
@@ -51,17 +54,24 @@
)


def _init_datastore(datastore_name):
def init_datastore(datastore_name):
DatastoreClass = DATASTORES[datastore_name]
return DatastoreClass(**EXAMPLES[datastore_name])


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_root_path(datastore_name):
"""Check that the `datastore.root_path` property is implemented."""
datastore = init_datastore(datastore_name)
assert isinstance(datastore.root_path, Path)


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_datastore_grid_xy(datastore_name):
"""Use the `datastore.get_xy` method to get the x, y coordinates of the
dataset and check that the shape is correct against the
`datastore.grid_shape_state` property."""
datastore = _init_datastore(datastore_name)
datastore = init_datastore(datastore_name)

# check the shapes of the xy grid
grid_shape = datastore.grid_shape_state
@@ -78,14 +88,6 @@ def test_datastore_grid_xy(datastore_name):
assert xy.shape == (2, ny, nx)


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_projection(datastore_name):
"""Check that the `datastore.coords_projection` property is implemented."""
datastore = _init_datastore(datastore_name)

assert isinstance(datastore.coords_projection, ccrs.Projection)


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_get_vars(datastore_name):
"""Check that results of.
@@ -97,7 +99,7 @@
are consistent (as in the number of variables are the same) and that the
return types of each are correct.
"""
datastore = _init_datastore(datastore_name)
datastore = init_datastore(datastore_name)

for category in ["state", "forcing", "static"]:
units = datastore.get_vars_units(category)
@@ -114,7 +116,7 @@
def test_get_normalization_dataarray(datastore_name):
"""Check that the `datastore.get_normalization_dataarray` method is
implemented."""
datastore = _init_datastore(datastore_name)
datastore = init_datastore(datastore_name)

for category in ["state", "forcing", "static"]:
ds_stats = datastore.get_normalization_dataarray(category=category)
@@ -144,7 +146,7 @@ def test_get_dataarray(datastore_name):
And that it returns an xarray DataArray with the correct dimensions.
"""

datastore = _init_datastore(datastore_name)
datastore = init_datastore(datastore_name)

for category in ["state", "forcing", "static"]:
for split in ["train", "val", "test"]:
@@ -175,7 +177,7 @@
def test_boundary_mask(datastore_name):
"""Check that the `datastore.boundary_mask` property is implemented and
that the returned object is an xarray DataArray with the correct shape."""
datastore = _init_datastore(datastore_name)
datastore = init_datastore(datastore_name)
da_mask = datastore.boundary_mask

assert isinstance(da_mask, xr.DataArray)
@@ -194,7 +196,7 @@
def test_get_xy_extent(datastore_name):
"""Check that the `datastore.get_xy_extent` method is implemented and that
the returned object is a tuple of the correct length."""
datastore = _init_datastore(datastore_name)
datastore = init_datastore(datastore_name)

if not isinstance(datastore, BaseCartesianDatastore):
pytest.skip("Datastore does not implement `BaseCartesianDatastore`")
@@ -216,7 +218,7 @@
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_get_xy(datastore_name):
"""Check that the `datastore.get_xy` method is implemented."""
datastore = _init_datastore(datastore_name)
datastore = init_datastore(datastore_name)

if not isinstance(datastore, BaseCartesianDatastore):
pytest.skip("Datastore does not implement `BaseCartesianDatastore`")
@@ -240,3 +242,14 @@
assert xy_unstacked.shape[0] == 2
assert xy_unstacked.shape[1] == ny
assert xy_unstacked.shape[2] == nx


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_get_projection(datastore_name):
"""Check that the `datastore.coords_projection` property is implemented."""
datastore = init_datastore(datastore_name)

if not isinstance(datastore, BaseCartesianDatastore):
pytest.skip("Datastore does not implement `BaseCartesianDatastore`")

assert isinstance(datastore.coords_projection, ccrs.Projection)
2 changes: 1 addition & 1 deletion tests/test_mllam_dataset.py
@@ -39,7 +39,7 @@ def test_mllam():

create_graph_from_datastore(
datastore=datastore,
graph_dir_path="tests/datastore_configs/mllam/graph",
output_root_path="tests/datastore_configs/mllam/graph",
)

model = GraphLAM( # noqa
2 changes: 1 addition & 1 deletion tests/test_multizarr_dataset.py
@@ -71,7 +71,7 @@ def test_create_graph_analysis_dataset():
config_path=DATASTORE_PATH / "data_config.yaml"
)
create_graph_from_datastore(
datastore=datastore, graph_dir_path=DATASTORE_PATH / "graph"
datastore=datastore, output_root_path=DATASTORE_PATH / "graph"
)

# test cli