Skip to content

Commit 5710eb2

Browse files
wip for levi's idea
1 parent 5c6c678 commit 5710eb2

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

src/curriculum/curriculum.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(self, num_states, model_name, state_generator, test_set_path,\
2828
self._time = [0]
2929
self._solution_quality = [0]
3030
self._solution_expansions = [0]
31+
self._traj = []
3132

3233

3334
self._log_folder = 'training_logs/'
@@ -73,6 +74,7 @@ def solve(self, states, planner, nn_model, budget, memory, update:bool):
7374
total_expanded += result[2]
7475
total_generated += result[3]
7576
puzzle_name = result[4]
77+
self._traj.append(trajectory)
7678

7779
if has_found_solution:
7880
#print(trajectory.get_solution_costs())
@@ -103,6 +105,11 @@ def solve(self, states, planner, nn_model, budget, memory, update:bool):
103105
batch_problems.clear()
104106
return (sum_sol_cost, number_solved, total_expanded, total_generated, sol_costs, sol_expansions)
105107

108+
def get_traj(self):
109+
traj = self._traj
110+
self._traj = []
111+
return traj
112+
106113
@abstractmethod
107114
def learn_online(self, planner, nn_model):
108115
pass

src/curriculum/lcb_curriculum.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,12 @@ def learn_online(self, planner, nn_model):
7373
states[i], difficulty = self.generate_state(nn_model, base_state, new_budget)
7474
difficulties.append(difficulty)
7575
#print(states[i], difficulty)
76+
77+
self._traj = []
7678
_, number_solved, total_expanded, total_generated, sol_costs, sol_expansions = self.solve(states,
7779
planner=planner, nn_model=nn_model, budget=budget, memory=memory, update=True)
7880

81+
trajs = self.get_traj()
7982
staters_per_itr = states
8083
expansions_per_tr = sol_expansions
8184
end = time.time()

0 commit comments

Comments
 (0)