logging losses
AOS55 committed Nov 6, 2022
1 parent fea7dae commit 78312e1
Showing 1 changed file with 6 additions and 2 deletions.
agents/unsupervised_learning/smm.py (6 additions, 2 deletions)
@@ -285,9 +285,10 @@ def update(self, replay_iter, step):
             # TODO: Assumes obs is just (x, y) at front
             p_star = self.get_goal_p_star(obs)
             log_p_star = np.log(p_star)
-            log_p_star = 0.5 * torch.tensor(log_p_star).to(self.device)
+            log_p_star = -100.0 * torch.tensor(log_p_star).to(self.device)
             # TODO: Check signs in this intrinsic reward function, maybe ask author
-            intr_reward = log_p_star + pred_log_ratios + self.latent_ent_coef * h_z + self.latent_cond_ent_coef * h_z_s.detach()
+            # intr_reward = log_p_star + pred_log_ratios + self.latent_ent_coef * h_z + self.latent_cond_ent_coef * h_z_s.detach()
+            intr_reward = pred_log_ratios + self.latent_ent_coef * h_z + self.latent_cond_ent_coef * h_z_s.detach()
             # print(f'intr_reward: {intr_reward[0]} = p*: {100 * log_p_star[0]} + rho_pi: {pred_log_ratios[0]} +h(z): {self.latent_ent_coef * h_z[0]} + h(z|s): {self.latent_cond_ent_coef * h_z_s.detach()[0]}')
             reward = intr_reward
         else:
@@ -300,6 +300,9 @@ def update(self, replay_iter, step):
         metrics['pred_log_ratios'] = pred_log_ratios.mean().item()
         metrics['latent_ent_coef'] = (self.latent_ent_coef * h_z).mean().item()
         metrics['latent_cond_ent_coef'] = (self.latent_cond_ent_coef * h_z_s.detach()).mean().item()
+        # add loss values
+        metrics['loss_vae'] = vae_metrics['loss_vae']
+        metrics['loss_pred'] = pred_metrics['loss_pred']

         if self.use_tb or self.use_wandb:
             metrics.update(vae_metrics)
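For readers tracing the sign TODO in the first hunk, below is a minimal standalone sketch (not the repository's code) of how the SMM-style intrinsic reward assembled there is composed. The tensor names mirror the diff; the coefficient defaults, tensor shapes, and the p_star_scale parameter are illustrative assumptions only.

# Sketch only: per-sample reward terms are assumed to be tensors of shape (batch, 1).
import torch

def smm_intrinsic_reward(pred_log_ratios, h_z, h_z_s,
                         log_p_star=None,
                         latent_ent_coef=1.0, latent_cond_ent_coef=1.0,
                         p_star_scale=-100.0):
    # Terms kept by the commit: density-ratio estimate plus latent-entropy bonuses.
    reward = (pred_log_ratios
              + latent_ent_coef * h_z
              + latent_cond_ent_coef * h_z_s.detach())
    # The goal-distribution term log p*(s) is commented out in the commit;
    # if re-enabled, it would be scaled and added back in roughly like this.
    if log_p_star is not None:
        reward = reward + p_star_scale * log_p_star
    return reward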

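The substance of the commit is the second hunk: 'loss_vae' and 'loss_pred' now travel with the rest of the metrics dict, so the existing use_tb / use_wandb path can write them out alongside the reward terms. A minimal sketch of that scalar-logging pattern, assuming the metric values are plain Python floats and using torch's TensorBoard SummaryWriter as a stand-in for whatever logger the repository actually dispatches to:

# Sketch only: generic scalar logging of a metrics dict, not the repo's logger.
from torch.utils.tensorboard import SummaryWriter

def log_metrics(metrics, step, writer=None):
    writer = writer or SummaryWriter(log_dir='runs/smm')
    for name, value in metrics.items():
        writer.add_scalar(name, value, global_step=step)

# Example: a dict shaped like the one built in update(), including the newly logged losses.
log_metrics({'loss_vae': 0.42, 'loss_pred': 1.7, 'pred_log_ratios': -0.3}, step=1000)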