Commit: add goal_behaviour
AOS55 committed Oct 28, 2022
1 parent c2bfb72 commit 0341ca5
Showing 1 changed file with 45 additions and 8 deletions.
53 changes: 45 additions & 8 deletions agents/unsupervised_learning/smm.py
@@ -147,15 +147,18 @@ def __init__(self, z_dim, sp_lr, vae_lr, vae_beta, state_ent_coef,
        self.update_encoder = update_encoder

        kwargs["meta_dim"] = self.z_dim
        # TODO: Fix this!
        self.obs_type = kwargs["obs_type"]
        super().__init__(**kwargs)
        # self.obs_dim is now the real obs_dim (or repr_dim) + z_dim
        self.smm = SMM(self.obs_dim - z_dim,
                       z_dim,
                       hidden_dim=kwargs['hidden_dim'],
                       vae_beta=vae_beta,
                       device=kwargs['device']).to(kwargs['device'])

        self.goal = (150, 75)  # TODO: Fix as part of config
        self.pred_optimizer = torch.optim.Adam(self.smm.z_pred_net.parameters(), lr=sp_lr)
        self.vae_optimizer = torch.optim.Adam(self.smm.vae.parameters(), lr=vae_lr)

@@ -236,6 +239,21 @@ def update_pred(self, obs, z):

        return metrics, h_z_s

    def get_goal_p_star(self, agent_pos):
        # Goal prior p*(s): 1.0 within unit distance of the goal, decaying as 1/d outside
        x_dist = agent_pos[:, 0] - self.goal[0]
        y_dist = agent_pos[:, 1] - self.goal[1]
        x_dist = x_dist.cpu().detach().numpy()
        y_dist = y_dist.cpu().detach().numpy()
        dist = np.linalg.norm((x_dist, y_dist), axis=0)

        def _prior_distro(dist):
            if dist > 1.0:
                p_star = 1 / dist
            else:
                p_star = 1.0
            return p_star

        p_star = np.array(list(map(_prior_distro, dist)), dtype=np.float32)
        return p_star
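
For reference, a minimal sketch (outside the diff) of what this prior computes, assuming state observations carry (x, y) in their first two columns; the batch values are hypothetical:

    import numpy as np
    import torch

    goal = (150, 75)
    obs = torch.tensor([[150.0, 75.0], [153.0, 79.0], [150.5, 75.0]])
    dist = np.linalg.norm((obs[:, 0].numpy() - goal[0],
                           obs[:, 1].numpy() - goal[1]), axis=0)
    p_star = np.where(dist > 1.0, 1.0 / np.maximum(dist, 1e-8), 1.0)
    # -> [1.0, 0.2, 1.0]: p* saturates at 1 within unit distance of the goal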

    def update(self, replay_iter, step):
        metrics = dict()
        if step % self.update_every_steps != 0:
@@ -244,7 +262,6 @@ def update(self, replay_iter, step):

        obs, action, extr_reward, discount, next_obs, z = utils.to_torch(
            batch, self.device)
        obs = self.aug_and_encode(obs)
        with torch.no_grad():
            next_obs = self.aug_and_encode(next_obs)
@@ -258,14 +275,34 @@ def update(self, replay_iter, step):
            h_z = np.log(self.z_dim)  # One-hot z encoding
            h_z *= torch.ones_like(extr_reward).to(self.device)

            pred_log_ratios = self.state_ent_coef * h_s_z.detach()

            if self.obs_type == 'pixels':
                # p^*(s) is ignored, as the state space dimension is inaccessible from pixel input
                intr_reward = (pred_log_ratios + self.latent_ent_coef * h_z
                               + self.latent_cond_ent_coef * h_z_s.detach())
                reward = intr_reward
            else:
                # p^*(s) is based on the goal hitting time
                # TODO: Assumes obs is just (x, y) at front
                p_star = self.get_goal_p_star(obs)
                log_p_star = np.log(p_star)
                log_p_star = 0.5 * torch.tensor(log_p_star).to(self.device)
                # TODO: Check signs in this intrinsic reward function, maybe ask author
                intr_reward = (log_p_star + pred_log_ratios + self.latent_ent_coef * h_z
                               + self.latent_cond_ent_coef * h_z_s.detach())
                # print(f'intr_reward: {intr_reward[0]} = p*: {100 * log_p_star[0]} + rho_pi: {pred_log_ratios[0]} + h(z): {self.latent_ent_coef * h_z[0]} + h(z|s): {self.latent_cond_ent_coef * h_z_s.detach()[0]}')
                reward = intr_reward
        else:
            reward = extr_reward

        if self.obs_type == 'states' and self.reward_free:
            # Log the components of the reward-free intrinsic reward for state observations
            metrics['intr_reward'] = intr_reward.mean().item()
            metrics['log_p_star'] = log_p_star.mean().item()
            metrics['pred_log_ratios'] = pred_log_ratios.mean().item()
            metrics['latent_ent_coef'] = (self.latent_ent_coef * h_z).mean().item()
            metrics['latent_cond_ent_coef'] = (self.latent_cond_ent_coef * h_z_s.detach()).mean().item()

        if self.use_tb or self.use_wandb:
            metrics.update(vae_metrics)
            metrics.update(pred_metrics)
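
Taken together, the state branch's intrinsic reward is r = 0.5 * log p*(s) + state_ent_coef * h(s|z) + latent_ent_coef * h(z) + latent_cond_ent_coef * h(z|s). A minimal numeric sketch of that composition, with hypothetical coefficients and entropy values:

    import numpy as np
    import torch

    state_ent_coef = latent_ent_coef = latent_cond_ent_coef = 1.0  # hypothetical
    z_dim = 4
    p_star = np.array([0.2, 1.0], dtype=np.float32)  # as returned by get_goal_p_star
    log_p_star = 0.5 * torch.from_numpy(np.log(p_star))
    h_s_z = torch.tensor([1.3, 0.9])                 # state-entropy estimates (VAE)
    h_z = np.log(z_dim) * torch.ones(2)              # uniform one-hot skill prior
    h_z_s = torch.tensor([0.7, 1.1])                 # skill-predictor log-ratios
    intr_reward = (log_p_star + state_ent_coef * h_s_z
                   + latent_ent_coef * h_z + latent_cond_ent_coef * h_z_s)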
