From 2d6a390c27e9996484b0cc91bad722ba73968925 Mon Sep 17 00:00:00 2001 From: Aakarshan-Chauhan <54583758+Aakarshan-Chauhan@users.noreply.github.com> Date: Sun, 11 Apr 2021 03:37:22 +0530 Subject: [PATCH] Fixed runtime error for continuous single action envs --- PPO.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/PPO.py b/PPO.py index b5d6d4c..a8e01fd 100644 --- a/PPO.py +++ b/PPO.py @@ -123,10 +123,12 @@ def evaluate(self, state, action): action_var = self.action_var.expand_as(action_mean) cov_mat = torch.diag_embed(action_var).to(device) dist = MultivariateNormal(action_mean, cov_mat) + + action = action.reshape(-1, self.action_dim) + else: action_probs = self.actor(state) dist = Categorical(action_probs) - action = action.reshape(-1, self.action_dim) action_logprobs = dist.log_prob(action) dist_entropy = dist.entropy() state_values = self.critic(state)