From a34bc1eeb5d374d1d4cf7070c0529a3b7312e996 Mon Sep 17 00:00:00 2001
From: xunzhang
Date: Wed, 22 Jan 2020 18:25:14 -0800
Subject: [PATCH] fix squeeze under 1d action case

---
 PPO_continuous.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/PPO_continuous.py b/PPO_continuous.py
index af292b0..e74139a 100644
--- a/PPO_continuous.py
+++ b/PPO_continuous.py
@@ -61,14 +61,14 @@ def act(self, state, memory):
         return action.detach()
 
     def evaluate(self, state, action):
-        action_mean = torch.squeeze(self.actor(state))
+        action_mean = self.actor(state)
 
         action_var = self.action_var.expand_as(action_mean)
         cov_mat = torch.diag_embed(action_var).to(device)
 
         dist = MultivariateNormal(action_mean, cov_mat)
 
-        action_logprobs = dist.log_prob(torch.squeeze(action))
+        action_logprobs = dist.log_prob(action)
         dist_entropy = dist.entropy()
         state_value = self.critic(state)
 
@@ -109,9 +109,9 @@ def update(self, memory):
         rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)
 
         # convert list to tensor
-        old_states = torch.squeeze(torch.stack(memory.states).to(device)).detach()
-        old_actions = torch.squeeze(torch.stack(memory.actions).to(device)).detach()
-        old_logprobs = torch.squeeze(torch.stack(memory.logprobs)).to(device).detach()
+        old_states = torch.squeeze(torch.stack(memory.states).to(device), 1).detach()
+        old_actions = torch.squeeze(torch.stack(memory.actions).to(device), 1).detach()
+        old_logprobs = torch.squeeze(torch.stack(memory.logprobs), 1).to(device).detach()
 
         # Optimize policy for K epochs:
         for _ in range(self.K_epochs):
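
Note on the change (illustrative, not part of the patch): torch.squeeze with no
dim argument drops every size-1 dimension, so for a 1-dimensional action space
it also removes the action dimension itself, which breaks the shapes that
MultivariateNormal and the downstream ratio computation expect. Passing dim=1
removes only the extra dimension introduced by stacking. A minimal sketch of
the shape difference, assuming stored actions of shape (1, 1) and a
hypothetical batch of 4000 timesteps:

    import torch

    # memory.actions for a 1d action space: 4000 tensors of shape (1, 1)
    actions = [torch.zeros(1, 1) for _ in range(4000)]
    stacked = torch.stack(actions)          # shape (4000, 1, 1)

    print(torch.squeeze(stacked).shape)     # torch.Size([4000])    -> action dim lost
    print(torch.squeeze(stacked, 1).shape)  # torch.Size([4000, 1]) -> action dim kept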