-
Notifications
You must be signed in to change notification settings - Fork 56
/
Copy pathtest_time_slicing.py
146 lines (118 loc) · 4.46 KB
/
test_time_slicing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Third-party
import numpy as np
import pytest
import xarray as xr
# First-party
from neural_lam.datastore.base import BaseDatastore
from neural_lam.weather_dataset import WeatherDataset
class SinglePointDummyDatastore(BaseDatastore):
    """
    Minimal in-memory datastore with a single grid point and one variable
    per category, used to check how samples are sliced along time.

    Only analysis data (`is_forecast=False`) can be served by
    `get_dataarray`; requesting data from a forecast datastore raises
    `NotImplementedError`, as do all metadata accessors.
    """

    # Fixed attributes for this dummy datastore. They are presumably the
    # attributes `BaseDatastore` expects subclasses to provide (confirm
    # against the base class).
    step_length = 1
    config = None
    coords_projection = None
    num_grid_points = 1
    root_path = None

    def __init__(self, time_values, state_data, forcing_data, is_forecast):
        """
        Store the time axis and the state/forcing values as numpy arrays.

        Parameters
        ----------
        time_values : array-like
            Values used as the `time` coordinate.
        state_data : array-like
            State values; must be 1D for analysis data and 2D for
            forecast data.
        forcing_data : array-like
            Forcing values (their dimensionality is not checked here).
        is_forecast : bool
            Whether the datastore represents forecast data.
        """
        self._time_values = np.array(time_values)
        self._state_data = np.array(state_data)
        self._forcing_data = np.array(forcing_data)
        self.is_forecast = is_forecast

        # forecast state data needs two dimensions, analysis data only one
        required_ndim = 2 if is_forecast else 1
        assert self._state_data.ndim == required_ndim

    def get_num_data_vars(self, category):
        """Return 1: every category has exactly one variable."""
        return 1

    def get_dataarray(self, category, split):
        """
        Return the data for `category` as an `xr.DataArray`.

        The stored values are placed on a `time` dimension, size-one
        `grid_index` and `{category}_feature` dimensions are added, and
        the result is transposed to `self.expected_dim_order(category)`.
        `split` is ignored.

        Raises
        ------
        NotImplementedError
            If `category` is neither "state" nor "forcing", or if the
            datastore was created with `is_forecast=True`.
        """
        if category not in ("state", "forcing"):
            raise NotImplementedError(category)
        if category == "state":
            values = self._state_data
        else:
            values = self._forcing_data

        if self.is_forecast:
            raise NotImplementedError()

        da_time_only = xr.DataArray(
            values, dims=["time"], coords={"time": self._time_values}
        )
        # add `{category}_feature` and `grid_index` dimensions
        da_full = da_time_only.expand_dims("grid_index").expand_dims(
            f"{category}_feature"
        )
        return da_full.transpose(*self.expected_dim_order(category=category))

    # The remaining accessors are deliberately left unimplemented in this
    # dummy datastore.

    def get_standardization_dataarray(self, category):
        """Not implemented for this dummy datastore."""
        raise NotImplementedError()

    def get_xy(self, category):
        """Not implemented for this dummy datastore."""
        raise NotImplementedError()

    def get_vars_units(self, category):
        """Not implemented for this dummy datastore."""
        raise NotImplementedError()

    def get_vars_names(self, category):
        """Not implemented for this dummy datastore."""
        raise NotImplementedError()

    def get_vars_long_names(self, category):
        """Not implemented for this dummy datastore."""
        raise NotImplementedError()
# Ten consecutive state values (0..9) and ten forcing values (10..19); the
# two value ranges do not overlap, so state and forcing entries can be told
# apart in the sliced samples.
ANALYSIS_STATE_VALUES = list(range(10))
FORCING_VALUES = list(range(10, 20))
@pytest.mark.parametrize(
    "ar_steps,num_past_forcing_steps,num_future_forcing_steps",
    [[3, 0, 0], [3, 1, 0], [3, 2, 0], [3, 3, 0]],
)
def test_time_slicing_analysis(
    ar_steps, num_past_forcing_steps, num_future_forcing_steps
):
    """
    Check the first sample `WeatherDataset` returns for analysis data.

    A single-point datastore is filled with the known values in
    `ANALYSIS_STATE_VALUES` and `FORCING_VALUES`, and the initial states,
    target states and windowed forcing of sample 0 are compared against
    hand-written expectations for each forcing window size.
    """
    # state and forcing variables have only one dimension, `time`
    n_times = len(ANALYSIS_STATE_VALUES)
    time_values = np.datetime64("2020-01-01") + np.arange(n_times)
    assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values)

    datastore = SinglePointDummyDatastore(
        time_values=time_values,
        state_data=ANALYSIS_STATE_VALUES,
        forcing_data=FORCING_VALUES,
        is_forecast=False,
    )
    dataset = WeatherDataset(
        datastore=datastore,
        ar_steps=ar_steps,
        num_past_forcing_steps=num_past_forcing_steps,
        num_future_forcing_steps=num_future_forcing_steps,
        standardize=False,
    )

    first_sample = dataset[0]
    # the fourth entry of the sample (target times) is not checked here
    init_states, target_states, forcing, _ = [
        part.numpy() for part in first_sample
    ]

    # expectations are only written down for three autoregressive steps
    if ar_steps != 3:
        raise NotImplementedError()
    expected_init_states = [0, 1]
    expected_target_states = [2, 3, 4]

    # expected forcing per (past steps, future steps) window; one row per
    # target step
    forcing_by_window = {
        (0, 0): [[12], [13], [14]],
        (1, 0): [[11, 12], [12, 13], [13, 14]],
        (2, 0): [[10, 11, 12], [11, 12, 13], [12, 13, 14]],
        (3, 0): [
            [10, 11, 12, 13],
            [11, 12, 13, 14],
            [12, 13, 14, 15],
        ],
    }
    window = (num_past_forcing_steps, num_future_forcing_steps)
    if window not in forcing_by_window:
        raise NotImplementedError()
    expected_forcing_values = forcing_by_window[window]

    if window == (3, 0):
        # With three past forcing steps the expected sample starts one time
        # step later; presumably the window for the first target would
        # otherwise reach before the first available time value.
        expected_init_states = [1, 2]
        expected_target_states = [3, 4, 5]

    # init_states: (2, N_grid, d_features)
    # target_states: (ar_steps, N_grid, d_features)
    # forcing: (ar_steps, N_grid, d_windowed_forcing)
    # target_times: (ar_steps,)
    assert init_states.shape == (2, 1, 1)
    assert init_states[:, 0, 0].tolist() == expected_init_states

    assert target_states.shape == (3, 1, 1)
    assert target_states[:, 0, 0].tolist() == expected_target_states

    window_size = 1 + num_past_forcing_steps + num_future_forcing_steps
    assert forcing.shape == (3, 1, window_size)
    np.testing.assert_equal(forcing[:, 0, :], np.array(expected_forcing_values))