Skip to content

Commit d77ae3d

Browse files
committed
Recursively parse default env functions
1 parent 275f9b6 commit d77ae3d

File tree

2 files changed

+37
-14
lines changed

2 files changed

+37
-14
lines changed

python/ecole/src/ecole/data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ def parse(something, default):
2323
2424
"""
2525
if something == "default":
26-
return default
26+
if default is None:
27+
raise ValueError("""Cannot parse "default" without a default value.""")
28+
return parse(default, None)
2729
elif something is None:
2830
return NoneFunction()
2931
elif isinstance(something, numbers.Number):

python/ecole/tests/test_data.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,21 @@ def test_MapFunction(model, done):
5151
assert data == {"name1": "something", "name2": "else"}
5252

5353

54+
@pytest.mark.parametrize("done", (True, False))
55+
@pytest.mark.parametrize("wall", (True, False))
56+
def test_TimedFunction(model, done, wall):
57+
"""Time a given data function."""
58+
data_func = mock.MagicMock()
59+
time_data_func = ecole.data.TimedFunction(data_func, wall=wall)
60+
61+
time_data_func.before_reset(model)
62+
data_func.before_reset.assert_called_once_with(model)
63+
64+
pytest.helpers.advance_to_stage(model, ecole.scip.Stage.Solving)
65+
time = time_data_func.extract(model, done)
66+
assert time > 0
67+
68+
5469
def test_parse_None():
5570
"""None is parsed as NoneFunction."""
5671
assert isinstance(ecole.data.parse(None, mock.MagicMock()), ecole.data.NoneFunction)
@@ -62,6 +77,12 @@ def test_parse_default():
6277
assert ecole.data.parse("default", default_func) == default_func
6378

6479

80+
def test_parse_self_reference():
81+
"""Default can not be used in the default function."""
82+
with pytest.raises(ValueError):
83+
ecole.data.parse("default", "default")
84+
85+
6586
def test_parse_number():
6687
"""Number return ConstantFunction."""
6788
assert isinstance(ecole.data.parse(1, mock.MagicMock()), ecole.data.ConstantFunction)
@@ -98,16 +119,16 @@ def test_parse_recursive(model):
98119
assert data["name3"] == default_func.extract.return_value
99120

100121

101-
@pytest.mark.parametrize("done", (True, False))
102-
@pytest.mark.parametrize("wall", (True, False))
103-
def test_MapFunction(model, done, wall):
104-
"""Time a given data function."""
105-
data_func = mock.MagicMock()
106-
time_data_func = ecole.data.TimedFunction(data_func, wall=wall)
107-
108-
time_data_func.before_reset(model)
109-
data_func.before_reset.assert_called_once_with(model)
110-
111-
pytest.helpers.advance_to_stage(model, ecole.scip.Stage.Solving)
112-
time = time_data_func.extract(model, done)
113-
assert time > 0
122+
def test_parse_recursive_default(model):
123+
"""Default function is parsed as well."""
124+
aggregate = {
125+
"name1": mock.MagicMock(),
126+
"name2": (mock.MagicMock(), None, 1),
127+
}
128+
func = ecole.data.parse("default", aggregate)
129+
# Using the extract method to inspect the recusive parsing since Vector, Map, Constant functions are private.
130+
data = func.extract(model, False)
131+
assert isinstance(data, dict)
132+
assert isinstance(data["name2"], list)
133+
assert data["name2"][1] is None
134+
assert data["name2"][2] == 1

0 commit comments

Comments
 (0)