diff --git a/PPO.py b/PPO.py index 3264eaa..b5d6d4c 100644 --- a/PPO.py +++ b/PPO.py @@ -119,13 +119,14 @@ def evaluate(self, state, action): if self.has_continuous_action_space: action_mean = self.actor(state) + action_var = self.action_var.expand_as(action_mean) cov_mat = torch.diag_embed(action_var).to(device) dist = MultivariateNormal(action_mean, cov_mat) else: action_probs = self.actor(state) dist = Categorical(action_probs) - + action = action.reshape(-1, self.action_dim) action_logprobs = dist.log_prob(action) dist_entropy = dist.entropy() state_values = self.critic(state)