Skip to content

Commit fd7adb5

Browse files
don't test evertime for tsc and rw+
1 parent fb3973e commit fd7adb5

File tree

2 files changed

+45
-38
lines changed

2 files changed

+45
-38
lines changed

src/curriculum/rw_curriculum.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def learn_online(self, planner, nn_model):
2121
difficulty = 4
2222
budget = self._initial_budget
2323
test_solve = 0
24+
prev_test_acc = 0
2425
memory = Memory()
2526
## TODO: remove this TMP!
2627

@@ -47,27 +48,31 @@ def learn_online(self, planner, nn_model):
4748
end-start)))
4849
results_file.write('\n')
4950

50-
51-
52-
test_sol_qual, test_solved, test_expanded, test_generated, _, _ = self.solve(self._test_set,\
53-
planner = planner, nn_model = nn_model, budget = self._test_budget, memory = memory, update = False)
54-
55-
self._test_solution_quality = test_sol_qual
56-
self._test_expansions = test_expanded
57-
58-
test_solve = test_solved / len(self._test_set)
59-
print('Iteration: {}\t Train solved: {}\t Test Solved:{}% Difficulty: {}'.format(
60-
iteration, number_solved / len(states) * 100, test_solve * 100, difficulty))
61-
6251
self._time.append(self._time[-1] + (end - start))
63-
self._performance.append(test_solve)
6452
self._expansions.append(self._expansions[-1] + total_expanded)
65-
if test_solved == 0:
66-
self._solution_quality.append(0)
67-
self._solution_expansions.append(0)
53+
54+
if prev_test_acc > 0.6 or (iteration - 1) % 5 == 0:
55+
test_sol_qual, test_solved, test_expanded, test_generated, _, _ = self.solve(self._test_set,\
56+
planner = planner, nn_model = nn_model, budget = self._test_budget, memory = memory, update = False)
57+
58+
self._test_solution_quality = test_sol_qual
59+
self._test_expansions = test_expanded
60+
61+
test_solve = test_solved / len(self._test_set)
62+
prev_test_acc = test_solve
63+
print('Iteration: {}\t Train solved: {}\t Test Solved:{}% Difficulty: {}'.format(
64+
iteration, number_solved / len(states) * 100, test_solve * 100, difficulty))
65+
66+
self._performance.append(test_solve)
67+
if test_solved == 0:
68+
self._solution_quality.append(0)
69+
self._solution_expansions.append(0)
70+
else:
71+
self._solution_quality.append(test_sol_qual / test_solved)
72+
self._solution_expansions.append(test_expanded / test_solved)
6873
else:
69-
self._solution_quality.append(test_sol_qual / test_solved)
70-
self._solution_expansions.append(test_expanded / test_solved)
74+
print('Iteration: {}\t Train solved: {} Difficulty: {}'.format(
75+
iteration, number_solved / len(states) * 100, difficulty))
7176

7277
if self.solvable(nn_model, number_solved, total_expanded, total_generated):
7378
difficulty += 1

src/curriculum/ts_curriculum.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,12 @@ def learn_online(self, planner, nn_model):
2828
number_solved = 0
2929
total_expanded = 0
3030
total_generated = 0
31+
prev_test_acc = 0
3132
budget = self._initial_budget
3233
test_solve = 0
3334
memory = Memory()
3435

35-
teacher = CMAESTeacher(batch_size=self._states_per_difficulty, mean=4, std=4)
36+
teacher = CMAESTeacher(batch_size=self._states_per_difficulty, mean=4, std=100)
3637
## TODO: remove this TMP!
3738

3839
while test_solve < 1:
@@ -59,28 +60,29 @@ def learn_online(self, planner, nn_model):
5960
end-start)))
6061
results_file.write('\n')
6162

62-
63-
64-
test_sol_qual, test_solved, test_expanded, test_generated, _, _ = self.solve(self._test_set,
65-
planner=planner, nn_model=nn_model, budget=self._test_budget, memory=memory, update=False)
66-
67-
self._test_solution_quality = test_sol_qual
68-
self._test_expansions = test_expanded
69-
70-
test_solve = test_solved / len(self._test_set)
7163
mean_difficulty = sum(difficulties) / len(difficulties)
72-
print('Iteration: {}\t Train solved: {}\t Test Solved:{}% Mean Difficulty: {}'.format(
73-
iteration, number_solved / len(states) * 100, test_solve * 100, mean_difficulty))
74-
7564
self._time.append(self._time[-1] + (end - start))
76-
self._performance.append(test_solve)
7765
self._expansions.append(self._expansions[-1] + total_expanded)
78-
if test_solved == 0:
79-
self._solution_quality.append(0)
80-
self._solution_expansions.append(0)
81-
else:
82-
self._solution_quality.append(test_sol_qual / test_solved)
83-
self._solution_expansions.append(test_expanded / test_solved)
66+
67+
if prev_test_acc > 0.6 or (iteration - 1) % 5 == 0:
68+
test_sol_qual, test_solved, test_expanded, test_generated, _, _ = self.solve(self._test_set,
69+
planner=planner, nn_model=nn_model, budget=self._test_budget, memory=memory, update=False)
70+
71+
self._test_solution_quality = test_sol_qual
72+
self._test_expansions = test_expanded
73+
74+
test_solve = test_solved / len(self._test_set)
75+
prev_test_acc = test_solve
76+
print('Iteration: {}\t Train solved: {}\t Test Solved:{}% Mean Difficulty: {}'.format(
77+
iteration, number_solved / len(states) * 100, test_solve * 100, mean_difficulty))
78+
79+
self._performance.append(test_solve)
80+
if test_solved == 0:
81+
self._solution_quality.append(0)
82+
self._solution_expansions.append(0)
83+
else:
84+
self._solution_quality.append(test_sol_qual / test_solved)
85+
self._solution_expansions.append(test_expanded / test_solved)
8486

8587
#TODO: get rewards
8688
rewards = self.get_rewards(sol_expansions, difficulties) #TODO: make it expanded

0 commit comments

Comments
 (0)