Skip to content

Commit 4cee102

Browse files
committed
improve tests
1 parent 4822652 commit 4cee102

File tree

8 files changed

+148
-64
lines changed

8 files changed

+148
-64
lines changed

config/network/network.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ defaults:
44
- dynamics: dynamics
55
- edge_config: edge_config
66
- node_config: node_config
7+
- stimulus_config: stimulus_config
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
type: Stimulus
2+
init_buffer: false

examples/01_flyvision_connectome.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@
152152
"# This json-file includes a list of cell types (`nodes`) and average convolutional filters\n",
153153
"# (anatomical receptive fields) (`edges`) that are scattered across a regular hexagonal lattice\n",
154154
"# of 15 column extent and stored on the hierarchical filesystem as h5-files.\n",
155-
"config = dict(file=connectome_file, extent=15, n_syn_fill=1)\n",
155+
"config = dict(file=connectome_file.name, extent=15, n_syn_fill=1)\n",
156156
"connectome = ConnectomeFromAvgFilters(config)"
157157
]
158158
},

tests/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
def connectome(tmp_path_factory):
1414
return ConnectomeFromAvgFilters(
1515
tmp_path_factory.mktemp("tmp") / "test",
16-
dict(file=connectome_file, extent=1, n_syn_fill=1),
16+
dict(file=connectome_file.name, extent=1, n_syn_fill=1),
1717
)
1818

1919

tests/test_paths.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import os
2+
import tempfile
3+
from pathlib import Path
4+
5+
import flyvis
6+
7+
8+
def test_default_root_dir():
9+
"""Test that the default root directory is correctly set relative to the package."""
10+
# Clear any existing FLYVIS_ROOT_DIR environment variable
11+
if "FLYVIS_ROOT_DIR" in os.environ:
12+
del os.environ["FLYVIS_ROOT_DIR"]
13+
14+
# Get the resolved root directory
15+
root_dir = flyvis.resolve_root_dir()
16+
17+
# Expected default is repo_dir/data
18+
expected_dir = flyvis.repo_dir / "data"
19+
assert root_dir == expected_dir.absolute()
20+
21+
22+
def test_env_var_root_dir():
23+
"""Test that FLYVIS_ROOT_DIR environment variable is respected."""
24+
with tempfile.TemporaryDirectory() as temp_dir:
25+
# Set environment variable to temp directory
26+
os.environ["FLYVIS_ROOT_DIR"] = temp_dir
27+
28+
# Get the resolved root directory
29+
root_dir = flyvis.resolve_root_dir()
30+
31+
assert root_dir == Path(temp_dir).absolute()
32+
33+
34+
def test_dotenv_file():
35+
"""Test that .env file is properly loaded."""
36+
with tempfile.TemporaryDirectory() as temp_dir:
37+
temp_path = Path(temp_dir)
38+
39+
# Create a .env file
40+
env_content = f'FLYVIS_ROOT_DIR={temp_dir}'
41+
with open(temp_path / '.env', 'w') as f:
42+
f.write(env_content)
43+
44+
# Clear any existing environment variable
45+
if "FLYVIS_ROOT_DIR" in os.environ:
46+
del os.environ["FLYVIS_ROOT_DIR"]
47+
48+
# Change to temp directory and reload environment
49+
original_dir = os.getcwd()
50+
os.chdir(temp_dir)
51+
52+
try:
53+
import dotenv
54+
55+
dotenv.load_dotenv(dotenv.find_dotenv(usecwd=True))
56+
57+
# Get the resolved root directory
58+
root_dir = flyvis.resolve_root_dir()
59+
60+
assert root_dir == temp_path.absolute()
61+
62+
finally:
63+
# Restore original directory
64+
os.chdir(original_dir)
65+
66+
67+
def test_path_expansion():
68+
"""Test that user path expansion works correctly."""
69+
test_path = "~/flyvis_test"
70+
os.environ["FLYVIS_ROOT_DIR"] = test_path
71+
72+
root_dir = flyvis.resolve_root_dir()
73+
expected_dir = Path(test_path).expanduser().absolute()
74+
75+
assert root_dir == expected_dir

tests/test_sintel.py

+45-43
Original file line numberDiff line numberDiff line change
@@ -39,30 +39,31 @@ def test_rendering(mock_sintel_data, tmp_path_factory):
3939

4040

4141
@pytest.fixture(scope="module")
42-
def dataset(mock_sintel_data):
43-
return MultiTaskSintel(
44-
tasks=["flow"],
45-
boxfilter=dict(extent=1, kernel_size=13),
46-
vertical_splits=3,
47-
n_frames=4,
48-
center_crop_fraction=0.7,
49-
dt=1 / 24,
50-
augment=True,
51-
random_temporal_crop=True,
52-
all_frames=False,
53-
resampling=True,
54-
interpolate=True,
55-
p_flip=0.5,
56-
p_rot=5 / 6,
57-
contrast_std=0.2,
58-
brightness_std=0.1,
59-
gaussian_white_noise=0.08,
60-
gamma_std=None,
61-
_init_cache=True,
62-
unittest=True,
63-
flip_axes=[0, 1, 2, 3],
64-
sintel_path=mock_sintel_data,
65-
)
42+
def dataset(mock_sintel_data, tmp_path_factory):
43+
with set_root_context(tmp_path_factory.mktemp("tmp")):
44+
return MultiTaskSintel(
45+
tasks=["flow"],
46+
boxfilter=dict(extent=1, kernel_size=13),
47+
vertical_splits=3,
48+
n_frames=4,
49+
center_crop_fraction=0.7,
50+
dt=1 / 24,
51+
augment=True,
52+
random_temporal_crop=True,
53+
all_frames=False,
54+
resampling=True,
55+
interpolate=True,
56+
p_flip=0.5,
57+
p_rot=5 / 6,
58+
contrast_std=0.2,
59+
brightness_std=0.1,
60+
gaussian_white_noise=0.08,
61+
gamma_std=None,
62+
_init_cache=True,
63+
unittest=True,
64+
flip_axes=[0, 1, 2, 3],
65+
sintel_path=mock_sintel_data,
66+
)
6667

6768

6869
@pytest.fixture(
@@ -81,25 +82,26 @@ def tasks(request):
8182
return request.param
8283

8384

84-
def test_init(tasks, mock_sintel_data):
85-
dataset = MultiTaskSintel(
86-
tasks=tasks, n_frames=4, unittest=True, sintel_path=mock_sintel_data
87-
)
88-
assert hasattr(dataset, "tasks")
89-
assert "lum" in dataset.data_keys
90-
assert hasattr(dataset, "config")
91-
assert hasattr(dataset, "meta")
92-
assert hasattr(dataset, "cached_sequences")
93-
assert hasattr(dataset, "arg_df")
94-
assert set(dataset[0].keys()) == set(["lum", *tasks])
95-
dataset = MultiTaskSintel(
96-
tasks=tasks,
97-
n_frames=4,
98-
unittest=True,
99-
_init_cache=False,
100-
sintel_path=mock_sintel_data,
101-
)
102-
assert not hasattr(dataset, "cached_sequences")
85+
def test_init(tasks, mock_sintel_data, tmp_path_factory):
86+
with set_root_context(tmp_path_factory.mktemp("tmp")):
87+
dataset = MultiTaskSintel(
88+
tasks=tasks, n_frames=4, unittest=True, sintel_path=mock_sintel_data
89+
)
90+
assert hasattr(dataset, "tasks")
91+
assert "lum" in dataset.data_keys
92+
assert hasattr(dataset, "config")
93+
assert hasattr(dataset, "meta")
94+
assert hasattr(dataset, "cached_sequences")
95+
assert hasattr(dataset, "arg_df")
96+
assert set(dataset[0].keys()) == set(["lum", *tasks])
97+
dataset = MultiTaskSintel(
98+
tasks=tasks,
99+
n_frames=4,
100+
unittest=True,
101+
_init_cache=False,
102+
sintel_path=mock_sintel_data,
103+
)
104+
assert not hasattr(dataset, "cached_sequences")
103105

104106

105107
def test_init_augmentation(dataset):

tests/test_solver.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,15 @@ def solver(mock_sintel_data, tmp_path_factory) -> MultiTaskSolver:
1515
"task.n_iters=50",
1616
f"+task.dataset.sintel_path={str(mock_sintel_data)}",
1717
"task.original_split=false",
18+
"task.dataset.boxfilter.extent=1",
19+
"task.dataset.n_frames=4",
20+
"task.dataset.dt=0.041",
21+
"task.batch_size=2",
22+
"network.connectome.extent=1",
1823
],
1924
)
2025
with set_root_context(str(tmp_path_factory.mktemp("tmp"))):
21-
solver = MultiTaskSolver("test", config)
22-
return solver
26+
return MultiTaskSolver("test", config)
2327

2428

2529
def test_solver_config():

tests/test_tasks.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from datamate import Namespace
1+
from datamate import Namespace, set_root_context
22

33
from flyvis.task.tasks import Task
44

55

6-
def test_task(mock_sintel_data, connectome):
6+
def test_task(mock_sintel_data, connectome, tmp_path_factory):
77
task_config = Namespace(
88
type="Task",
99
dataset=Namespace(
@@ -50,19 +50,19 @@ def test_task(mock_sintel_data, connectome):
5050
train_seq_ind=None,
5151
val_seq_ind=None,
5252
)
53+
with set_root_context(tmp_path_factory.mktemp("tmp")):
54+
task = Task(
55+
task_config.dataset,
56+
task_config.decoder,
57+
task_config.loss,
58+
batch_size=4,
59+
n_iters=250_000,
60+
n_folds=4,
61+
fold=1,
62+
seed=0,
63+
)
64+
assert task is not None
65+
assert task.dataset is not None
5366

54-
task = Task(
55-
task_config.dataset,
56-
task_config.decoder,
57-
task_config.loss,
58-
batch_size=4,
59-
n_iters=250_000,
60-
n_folds=4,
61-
fold=1,
62-
seed=0,
63-
)
64-
assert task is not None
65-
assert task.dataset is not None
66-
67-
decoder = task.init_decoder(connectome)
68-
assert decoder is not None
67+
decoder = task.init_decoder(connectome)
68+
assert decoder is not None

0 commit comments

Comments
 (0)