Skip to content

Commit 734da01

Browse files
committed
🧪 EnglandCOVID and HundaryCP tests added
1 parent bcd1b12 commit 734da01

File tree

6 files changed

+104
-15
lines changed

6 files changed

+104
-15
lines changed

stgraph/dataset/dynamic/EnglandCovidDataLoader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def __init__(
5151
The number of time lags (default is 8)
5252
cutoff_time : int, optional
5353
The cutoff timestamp for the temporal dataset (default is None)
54+
redownload : bool, optional (default is False)
55+
Redownload the dataset online and save to cache
5456
5557
Attributes
5658
----------

stgraph/dataset/static/CoraDataLoader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def __init__(self, verbose=False, url=None, redownload=False) -> None:
5252
Flag to control whether to display verbose info (default is False)
5353
url : str, optional
5454
The URL from where the dataset is downloaded online (default is None)
55+
redownload : bool, optional (default is False)
56+
Redownload the dataset online and save to cache
5557
5658
Attributes
5759
----------

stgraph/dataset/temporal/HungaryCPDataLoader.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44

55

66
class HungaryCPDataLoader(STGraphTemporalDataset):
7-
def __init__(self, verbose=False, url=None, lags=4, cutoff_time=None) -> None:
7+
def __init__(
8+
self,
9+
verbose: bool = False,
10+
url: str = None,
11+
lags: int = 4,
12+
cutoff_time: int = None,
13+
redownload: bool = False,
14+
) -> None:
815
r"""County level chicken pox cases in Hungary
916
1017
This dataset comprises information on weekly occurrences of chickenpox
@@ -48,6 +55,8 @@ def __init__(self, verbose=False, url=None, lags=4, cutoff_time=None) -> None:
4855
The number of time lags (default is 4)
4956
cutoff_time : int, optional
5057
The cutoff timestamp for the temporal dataset (default is None)
58+
redownload : bool, optional (default is False)
59+
Redownload the dataset online and save to cache
5160
5261
Attributes
5362
----------
@@ -69,10 +78,15 @@ def __init__(self, verbose=False, url=None, lags=4, cutoff_time=None) -> None:
6978

7079
super().__init__()
7180

72-
assert lags > 0, "lags should be a positive integer"
73-
assert type(lags) == int, "lags should be of type int"
74-
assert cutoff_time > 0, "cutoff_time should be a positive integer"
75-
assert type(cutoff_time) == int, "cutoff_time should be a positive integer"
81+
if type(lags) != int:
82+
raise TypeError("lags must be of type int")
83+
if lags < 0:
84+
raise ValueError("lags must be a positive integer")
85+
86+
if cutoff_time != None and type(cutoff_time) != int:
87+
raise TypeError("cutoff_time must be of type int")
88+
if cutoff_time != None and cutoff_time < 0:
89+
raise ValueError("cutoff_time must be a positive integer")
7690

7791
self.name = "Hungary_Chickenpox"
7892
self._verbose = verbose
@@ -84,6 +98,9 @@ def __init__(self, verbose=False, url=None, lags=4, cutoff_time=None) -> None:
8498
else:
8599
self._url = url
86100

101+
if redownload and self._has_dataset_cache():
102+
self._delete_cached_dataset()
103+
87104
if self._has_dataset_cache():
88105
self._load_dataset()
89106
else:
@@ -142,9 +159,11 @@ def _set_edge_weights(self):
142159
def _set_targets_and_features(self):
143160
r"""Calculates and sets the target and feature attributes"""
144161
stacked_target = np.array(self._dataset["FX"])
145-
self._all_targets = np.array(
146-
[stacked_target[i, :].T for i in range(stacked_target.shape[0])]
147-
)
162+
163+
self._all_targets = [
164+
stacked_target[i + self._lags, :].T
165+
for i in range(self.gdata["total_timestamps"] - self._lags)
166+
]
148167

149168
def get_edges(self):
150169
r"""Returns the edge list"""

tests/dataset/dynamic/test_EnglandCovidDataLoader.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66

77
def EnglandCovidDataCheck(eng_covid: EnglandCovidDataLoader):
8-
# test for gdata
98
assert eng_covid.gdata["total_timestamps"] == (
109
61 if not eng_covid._cutoff_time else eng_covid._cutoff_time
1110
)
@@ -34,16 +33,32 @@ def EnglandCovidDataCheck(eng_covid: EnglandCovidDataLoader):
3433
for i in range(len(edge_list)):
3534
assert len(edge_list[i]) == len(edge_weights[i])
3635

37-
# test for features and targets
38-
# TODO:
36+
all_features = eng_covid.get_all_features()
37+
all_targets = eng_covid.get_all_targets()
38+
39+
assert len(all_features) == eng_covid.gdata["total_timestamps"] - eng_covid._lags
40+
41+
assert all_features[0].shape == (
42+
eng_covid.gdata["num_nodes"]["0"],
43+
eng_covid._lags,
44+
)
45+
46+
assert len(all_targets) == eng_covid.gdata["total_timestamps"] - eng_covid._lags
47+
48+
assert all_targets[0].shape == (eng_covid.gdata["num_nodes"]["0"],)
3949

4050

4151
def test_EnglandCovidDataLoader():
42-
eng_covid = EnglandCovidDataLoader()
52+
eng_covid = EnglandCovidDataLoader(verbose=True)
4353
eng_covid_1 = EnglandCovidDataLoader(cutoff_time=30)
4454
eng_covid_2 = EnglandCovidDataLoader(
4555
url="https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/england_covid.json"
4656
)
57+
eng_covid_3 = EnglandCovidDataLoader(lags=12)
58+
# eng_covid_4 = EnglandCovidDataLoader(redownload=True)
4759

4860
EnglandCovidDataCheck(eng_covid)
4961
EnglandCovidDataCheck(eng_covid_1)
62+
# EnglandCovidDataCheck(eng_covid_2)
63+
EnglandCovidDataCheck(eng_covid_3)
64+
# EnglandCovidDataCheck(eng_covid_4)

tests/dataset/static/test_CoraDataLoader.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
import numpy as np
2-
import urllib.request
3-
41
from stgraph.dataset import CoraDataLoader
52

63

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import pytest
2+
3+
from stgraph.dataset import HungaryCPDataLoader
4+
5+
6+
def HungaryCPDataChecker(hungary: HungaryCPDataLoader):
7+
assert hungary.gdata["total_timestamps"] == (
8+
521 if not hungary._cutoff_time else hungary._cutoff_time
9+
)
10+
11+
assert hungary.gdata["num_nodes"] == 20
12+
assert hungary.gdata["num_edges"] == 102
13+
14+
edges = hungary.get_edges()
15+
edge_weights = hungary.get_edge_weights()
16+
all_targets = hungary.get_all_targets()
17+
18+
assert len(edges) == 102
19+
assert len(edges[0]) == 2
20+
21+
assert len(edge_weights) == 102
22+
23+
assert len(all_targets) == hungary.gdata["total_timestamps"] - hungary._lags
24+
assert all_targets[0].shape == (hungary.gdata["num_nodes"],)
25+
26+
27+
def test_HungaryCPDataLoader():
28+
hungary_1 = HungaryCPDataLoader(verbose=True)
29+
hungary_2 = HungaryCPDataLoader(lags=6)
30+
hungary_3 = HungaryCPDataLoader(cutoff_time=100)
31+
hungary_4 = HungaryCPDataLoader(
32+
url="https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/HungaryCP.json"
33+
)
34+
35+
HungaryCPDataChecker(hungary_1)
36+
HungaryCPDataChecker(hungary_2)
37+
HungaryCPDataChecker(hungary_3)
38+
# HungaryCPDataChecker(hungary_4)
39+
40+
with pytest.raises(TypeError) as exec:
41+
HungaryCPDataLoader(lags="lags")
42+
assert str(exec.value) == "lags must be of type int"
43+
44+
with pytest.raises(ValueError) as exec:
45+
HungaryCPDataLoader(lags=-1)
46+
assert str(exec.value) == "lags must be a positive integer"
47+
48+
with pytest.raises(TypeError) as exec:
49+
HungaryCPDataLoader(cutoff_time="time")
50+
assert str(exec.value) == "cutoff_time must be of type int"
51+
52+
with pytest.raises(ValueError) as exec:
53+
HungaryCPDataLoader(cutoff_time=-1)
54+
assert str(exec.value) == "cutoff_time must be a positive integer"

0 commit comments

Comments
 (0)