Skip to content

Commit

Permalink
fix epoch count, add timestep logging
Browse files Browse the repository at this point in the history
  • Loading branch information
StephAO committed Sep 16, 2022
1 parent 07f27d8 commit 5d7e411
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions agents/rl_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,10 @@ def train_agents(self, total_timesteps=1e8, exp_name=None):
best_path, best_tag = None, None
while self.agents[self.p_idx].agent.num_timesteps < total_timesteps:
if epoch % 10 == 0:
- epoch += 1
score, done_training = self.eval_env.evaluate(self.agents[self.p_idx], num_trials=1)
scores.append(score)
print(f'Episode eval at epoch {epoch}: {score}')
- wandb.log({'eval_true_score': score, 'epoch': epoch})
+ wandb.log({'eval_true_score': score, 'epoch': epoch, 'timestep': self.agents[self.p_idx].agent.num_timesteps})
if score > best_score:
best_path, best_tag = self.save()
best_score = score
Expand All @@ -123,6 +122,7 @@ def train_agents(self, total_timesteps=1e8, exp_name=None):
self.agents[self.p_idx].learn(total_timesteps=EPOCH_TIMESTEPS)
for i in range(self.args.n_envs):
self.env.env_method('set_agent', self.teammates[np.random.randint(self.n_tm)], self.t_idx, indices=i)
+ epoch += 1

if best_path is not None:
self.load(best_path, best_tag)
Expand Down Expand Up @@ -178,11 +178,10 @@ def train_agents(self, total_timesteps=1e8, exp_name=None):
best_path, best_tag = None, None
while self.agents[0].agent.num_timesteps < total_timesteps:
if epoch % 10 == 0:
- epoch += 1
score = self.eval_env.run_full_episode()
scores.append(score)
print(f'Episode eval at epoch {epoch}: {score}')
- wandb.log({'eval_true_reward': score, 'epoch': epoch})
+ wandb.log({'eval_true_reward': score, 'epoch': epoch, 'timestep': self.agents[0].agent.num_timesteps})
if score > best_score:
best_path, best_tag = self.save()
best_score = score
Expand All @@ -191,6 +190,7 @@ def train_agents(self, total_timesteps=1e8, exp_name=None):
self.ck_list.append( (score, path, tag) )
self.agents[0].learn(total_timesteps=EPOCH_TIMESTEPS)
self.agents[1].learn(total_timesteps=EPOCH_TIMESTEPS)
+ epoch += 1

if best_path is not None:
self.load(best_path, best_tag)
Expand Down Expand Up @@ -261,15 +261,15 @@ def train_agents(self, total_timesteps=1e8, exp_name=None):
best_path, best_tag = None, None
while self.agents[0].agent.num_timesteps < total_timesteps:
if epoch % 10 == 0:
- epoch += 1
score = self.run_full_episode()
scores.append(score)
print(f'Episode eval at epoch {epoch}: {score}')
- wandb.log({'eval_true_reward': score, 'epoch': epoch})
+ wandb.log({'eval_true_reward': score, 'epoch': epoch, 'timestep': self.agents[0].agent.num_timesteps})
if score > best_score:
best_path, best_tag = self.save()
best_score = score
self.agents[0].learn(total_timesteps=EPOCH_TIMESTEPS)
+ epoch += 1

if best_path is not None:
self.load(best_path, best_tag)
Expand Down

0 comments on commit 5d7e411

Please sign in to comment.