Skip to content

Commit

Permalink
fix squeeze under 1d action case
Browse files Browse the repository at this point in the history
  • Loading branch information
xunzhang committed Jan 23, 2020
1 parent 87eaa74 commit a34bc1e
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions PPO_continuous.py
Original file line number Diff line number Diff line change
@@ -61,14 +61,14 @@ def act(self, state, memory):
return action.detach()

def evaluate(self, state, action):
- action_mean = torch.squeeze(self.actor(state))
+ action_mean = self.actor(state)

action_var = self.action_var.expand_as(action_mean)
cov_mat = torch.diag_embed(action_var).to(device)

dist = MultivariateNormal(action_mean, cov_mat)

- action_logprobs = dist.log_prob(torch.squeeze(action))
+ action_logprobs = dist.log_prob(action)
dist_entropy = dist.entropy()
state_value = self.critic(state)

@@ -109,9 +109,9 @@ def update(self, memory):
rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

# convert list to tensor
- old_states = torch.squeeze(torch.stack(memory.states).to(device)).detach()
- old_actions = torch.squeeze(torch.stack(memory.actions).to(device)).detach()
- old_logprobs = torch.squeeze(torch.stack(memory.logprobs)).to(device).detach()
+ old_states = torch.squeeze(torch.stack(memory.states).to(device), 1).detach()
+ old_actions = torch.squeeze(torch.stack(memory.actions).to(device), 1).detach()
+ old_logprobs = torch.squeeze(torch.stack(memory.logprobs), 1).to(device).detach()

# Optimize policy for K epochs:
for _ in range(self.K_epochs):
Expand Down

0 comments on commit a34bc1e

Please sign in to comment.