Skip to content

Commit f31da71

Browse files
committed
Refactor CLI tests and fix solver config paths
- Update test_filter_args to handle command order preservation and Hydra-style args - Restructure test_main_argument_parsing with clearer expected behavior patterns - Fix config paths in solver tests to use package-relative paths - Add test cases for empty input and multiple command scenarios - Improve test readability with more descriptive variable names
1 parent 65ed916 commit f31da71

File tree

2 files changed

+62
-24
lines changed

2 files changed

+62
-24
lines changed

tests/test_cli.py

+60-22
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,47 @@
1212

1313
def test_filter_args():
1414
# Test filtering out commands from arguments
15-
argv = ["--ensemble_id", "0001", "train", "--task_name", "flow", "validate"]
16-
commands = ["train", "validate"]
17-
filtered = filter_args(argv, commands)
18-
assert filtered == ["--ensemble_id", "0001", "--task_name", "flow"]
15+
argv = [
16+
"--ensemble_id",
17+
"0001",
18+
"train",
19+
"--task_name",
20+
"flow",
21+
"validate",
22+
"task.n_iters=100",
23+
"foo:int=1",
24+
"record",
25+
]
26+
commands = ["train", "validate", "record"]
27+
selected_commands, filtered_args = filter_args(argv, commands)
28+
assert selected_commands == ["train", "validate", "record"]
29+
assert filtered_args == [
30+
"--ensemble_id",
31+
"0001",
32+
"--task_name",
33+
"flow",
34+
"task.n_iters=100",
35+
"foo:int=1",
36+
]
1937

2038
# Test with no commands to filter
2139
argv = ["--ensemble_id", "0001", "--task_name", "flow"]
2240
commands = ["train", "validate"]
23-
filtered = filter_args(argv, commands)
24-
assert filtered == argv
41+
selected_commands, filtered_args = filter_args(argv, commands)
42+
assert selected_commands == []
43+
assert filtered_args == argv
2544

2645
# Test with empty input
27-
assert filter_args([], []) == []
46+
selected_commands, filtered_args = filter_args([], [])
47+
assert selected_commands == []
48+
assert filtered_args == []
49+
50+
# Test order preservation
51+
argv = ["validate", "--ensemble_id", "0001", "train", "--task_name", "flow"]
52+
commands = ["train", "validate"]
53+
selected_commands, filtered_args = filter_args(argv, commands)
54+
assert selected_commands == ["validate", "train"] # Commands preserved in order found
55+
assert filtered_args == ["--ensemble_id", "0001", "--task_name", "flow"]
2856

2957

3058
def test_handle_help_request():
@@ -42,29 +70,39 @@ def test_handle_help_request():
4270

4371

4472
@pytest.mark.parametrize(
45-
"args,expected_return",
73+
"args,expected_behavior",
4674
[
47-
(["--ensemble_id", "0001", "--task_name", "flow", "train"], 0),
48-
(["--help"], 0), # Help message exits with 0
49-
([""], 2), # Invalid flag exits with 2
75+
(
76+
["--ensemble_id", "0001", "--task_name", "flow", "train"],
77+
{"should_succeed": True, "expected_command": "train"},
78+
),
79+
(
80+
["--help"],
81+
{"should_succeed": False, "return_code": 1},
82+
),
83+
(
84+
[],
85+
{"should_succeed": False, "return_code": 1},
86+
),
5087
],
5188
)
52-
def test_main_argument_parsing(args, expected_return):
89+
def test_main_argument_parsing(args, expected_behavior):
90+
"""Test CLI argument parsing behavior."""
5391
with (
5492
patch('sys.argv', ['flyvis'] + args),
5593
patch('flyvis_cli.flyvis_cli.run_script') as mock_run,
5694
):
57-
if expected_return != 0:
58-
with pytest.raises(SystemExit) as exc_info:
59-
main()
60-
assert exc_info.value.code == expected_return
95+
result = main()
96+
97+
if expected_behavior["should_succeed"]:
98+
assert result == 0
99+
# Verify the correct script was called
100+
mock_run.assert_called_once()
101+
script_path = mock_run.call_args[0][0]
102+
assert script_path.name == f"{expected_behavior['expected_command']}.py"
61103
else:
62-
try:
63-
assert main() == expected_return
64-
except SystemExit as exc:
65-
assert exc.code == expected_return
66-
if "--help" not in args:
67-
mock_run.assert_called()
104+
# Help and usage errors return 1
105+
assert result == expected_behavior["return_code"]
68106

69107

70108
def test_main_runs_multiple_commands():

tests/test_solver.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
@pytest.fixture(scope="module")
99
def solver(mock_sintel_data, tmp_path_factory) -> MultiTaskSolver:
1010
config = get_default_config(
11-
path="../../config/solver.yaml",
11+
path="../../flyvis/config/solver.yaml",
1212
overrides=[
1313
"task_name=flow",
1414
"ensemble_and_network_id=0",
@@ -28,7 +28,7 @@ def solver(mock_sintel_data, tmp_path_factory) -> MultiTaskSolver:
2828

2929
def test_solver_config():
3030
config = get_default_config(
31-
path="../../config/solver.yaml",
31+
path="../../flyvis/config/solver.yaml",
3232
overrides=[
3333
"task_name=flow",
3434
"ensemble_and_network_id=0",

0 commit comments

Comments
 (0)