@@ -28,11 +28,12 @@ def learn_online(self, planner, nn_model):
28
28
number_solved = 0
29
29
total_expanded = 0
30
30
total_generated = 0
31
+ prev_test_acc = 0
31
32
budget = self ._initial_budget
32
33
test_solve = 0
33
34
memory = Memory ()
34
35
35
- teacher = CMAESTeacher (batch_size = self ._states_per_difficulty , mean = 4 , std = 4 )
36
+ teacher = CMAESTeacher (batch_size = self ._states_per_difficulty , mean = 4 , std = 100 )
36
37
## TODO: remove this TMP!
37
38
38
39
while test_solve < 1 :
@@ -59,28 +60,29 @@ def learn_online(self, planner, nn_model):
59
60
end - start )))
60
61
results_file .write ('\n ' )
61
62
62
-
63
-
64
- test_sol_qual , test_solved , test_expanded , test_generated , _ , _ = self .solve (self ._test_set ,
65
- planner = planner , nn_model = nn_model , budget = self ._test_budget , memory = memory , update = False )
66
-
67
- self ._test_solution_quality = test_sol_qual
68
- self ._test_expansions = test_expanded
69
-
70
- test_solve = test_solved / len (self ._test_set )
71
63
mean_difficulty = sum (difficulties ) / len (difficulties )
72
- print ('Iteration: {}\t Train solved: {}\t Test Solved:{}% Mean Difficulty: {}' .format (
73
- iteration , number_solved / len (states ) * 100 , test_solve * 100 , mean_difficulty ))
74
-
75
64
self ._time .append (self ._time [- 1 ] + (end - start ))
76
- self ._performance .append (test_solve )
77
65
self ._expansions .append (self ._expansions [- 1 ] + total_expanded )
78
- if test_solved == 0 :
79
- self ._solution_quality .append (0 )
80
- self ._solution_expansions .append (0 )
81
- else :
82
- self ._solution_quality .append (test_sol_qual / test_solved )
83
- self ._solution_expansions .append (test_expanded / test_solved )
66
+
67
+ if prev_test_acc > 0.6 or (iteration - 1 ) % 5 == 0 :
68
+ test_sol_qual , test_solved , test_expanded , test_generated , _ , _ = self .solve (self ._test_set ,
69
+ planner = planner , nn_model = nn_model , budget = self ._test_budget , memory = memory , update = False )
70
+
71
+ self ._test_solution_quality = test_sol_qual
72
+ self ._test_expansions = test_expanded
73
+
74
+ test_solve = test_solved / len (self ._test_set )
75
+ prev_test_acc = test_solve
76
+ print ('Iteration: {}\t Train solved: {}\t Test Solved:{}% Mean Difficulty: {}' .format (
77
+ iteration , number_solved / len (states ) * 100 , test_solve * 100 , mean_difficulty ))
78
+
79
+ self ._performance .append (test_solve )
80
+ if test_solved == 0 :
81
+ self ._solution_quality .append (0 )
82
+ self ._solution_expansions .append (0 )
83
+ else :
84
+ self ._solution_quality .append (test_sol_qual / test_solved )
85
+ self ._solution_expansions .append (test_expanded / test_solved )
84
86
85
87
#TODO: get rewards
86
88
rewards = self .get_rewards (sol_expansions , difficulties ) #TODO: make it expanded
0 commit comments