From 9cf68ff6cbd553cff8b37a7f203e7ebad631f6b4 Mon Sep 17 00:00:00 2001
From: p-christ
Date: Fri, 19 Jul 2019 10:33:11 +0100
Subject: [PATCH] fixed SAC discrete error that meant we weren't taking the mean of the actor loss
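
The discrete-SAC actor loss is the expectation under the policy of
alpha * log pi(a|s) - min(Q1(s,a), Q2(s,a)), averaged over the batch.
The previous code collapsed everything with torch.sum(), so the
subsequent .mean() acted on a 0-dim scalar and was a no-op, and the
loss magnitude grew with the batch size instead of being averaged.

For reference, a minimal sketch of the corrected computation (the
random softmax/Q tensors below are illustrative placeholders, not the
repo's actual network outputs):

    import torch

    batch_size, num_actions = 4, 3
    # Hypothetical stand-ins for the policy and critic outputs.
    action_probabilities = torch.softmax(torch.randn(batch_size, num_actions), dim=-1)
    log_action_probabilities = torch.log(action_probabilities + 1e-8)
    min_qf_pi = torch.randn(batch_size, num_actions)  # element-wise min of the two critics
    alpha = 0.2  # entropy temperature

    inside_term = alpha * log_action_probabilities - min_qf_pi
    # Buggy: torch.sum(action_probabilities * inside_term) returns a
    # scalar, so a following .mean() changes nothing.
    # Fixed: keep the element-wise product and average it.
    policy_loss = (action_probabilities * inside_term).mean()

Note that averaging over both the batch and action dimensions differs
from (action_probabilities * inside_term).sum(dim=1).mean() only by a
constant factor of num_actions, which the learning rate absorbs.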
---
agents/actor_critic_agents/SAC_Discrete.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/agents/actor_critic_agents/SAC_Discrete.py b/agents/actor_critic_agents/SAC_Discrete.py
index 405b0090..c0577c78 100644
--- a/agents/actor_critic_agents/SAC_Discrete.py
+++ b/agents/actor_critic_agents/SAC_Discrete.py
@@ -81,7 +81,7 @@ def calculate_actor_loss(self, state_batch):
         qf2_pi = self.critic_local_2(state_batch)
         min_qf_pi = torch.min(qf1_pi, qf2_pi)
         inside_term = self.alpha * log_action_probabilities - min_qf_pi
-        policy_loss = torch.sum(action_probabilities * inside_term)
+        policy_loss = action_probabilities * inside_term
         policy_loss = policy_loss.mean()
         log_action_probabilities = log_action_probabilities.gather(1, action.unsqueeze(-1).long())
         return policy_loss, log_action_probabilities
\ No newline at end of file