@@ -51,6 +51,21 @@ def test_MapFunction(model, done):
5151 assert data == {"name1" : "something" , "name2" : "else" }
5252
5353
54+ @pytest .mark .parametrize ("done" , (True , False ))
55+ @pytest .mark .parametrize ("wall" , (True , False ))
56+ def test_TimedFunction (model , done , wall ):
57+ """Time a given data function."""
58+ data_func = mock .MagicMock ()
59+ time_data_func = ecole .data .TimedFunction (data_func , wall = wall )
60+
61+ time_data_func .before_reset (model )
62+ data_func .before_reset .assert_called_once_with (model )
63+
64+ pytest .helpers .advance_to_stage (model , ecole .scip .Stage .Solving )
65+ time = time_data_func .extract (model , done )
66+ assert time > 0
67+
68+
5469def test_parse_None ():
5570 """None is parsed as NoneFunction."""
5671 assert isinstance (ecole .data .parse (None , mock .MagicMock ()), ecole .data .NoneFunction )
@@ -62,6 +77,12 @@ def test_parse_default():
6277 assert ecole .data .parse ("default" , default_func ) == default_func
6378
6479
80+ def test_parse_self_reference ():
81+ """Default can not be used in the default function."""
82+ with pytest .raises (ValueError ):
83+ ecole .data .parse ("default" , "default" )
84+
85+
6586def test_parse_number ():
6687 """Number return ConstantFunction."""
6788 assert isinstance (ecole .data .parse (1 , mock .MagicMock ()), ecole .data .ConstantFunction )
@@ -98,16 +119,16 @@ def test_parse_recursive(model):
98119 assert data ["name3" ] == default_func .extract .return_value
99120
100121
101- @ pytest . mark . parametrize ( "done" , ( True , False ))
102- @ pytest . mark . parametrize ( "wall" , ( True , False ))
103- def test_MapFunction ( model , done , wall ):
104- """Time a given data function."""
105- data_func = mock .MagicMock ()
106- time_data_func = ecole . data . TimedFunction ( data_func , wall = wall )
107-
108- time_data_func . before_reset ( model )
109- data_func . before_reset . assert_called_once_with (model )
110-
111- pytest . helpers . advance_to_stage ( model , ecole . scip . Stage . Solving )
112- time = time_data_func . extract ( model , done )
113- assert time > 0
122+ def test_parse_recursive_default ( model ):
123+ """Default function is parsed as well."""
124+ aggregate = {
125+ "name1" : mock . MagicMock (),
126+ "name2" : ( mock .MagicMock (), None , 1 ),
127+ }
128+ func = ecole . data . parse ( "default" , aggregate )
129+ # Using the extract method to inspect the recusive parsing since Vector, Map, Constant functions are private.
130+ data = func . extract (model , False )
131+ assert isinstance ( data , dict )
132+ assert isinstance ( data [ "name2" ], list )
133+ assert data [ "name2" ][ 1 ] is None
134+ assert data [ "name2" ][ 2 ] == 1
0 commit comments