Skip to content

Commit 4759d1f

Browse files
author
Ervin T
authored
[add-fire] Halve Gaussian entropy (#4319)
* Halve entropy
* Fix utils test
1 parent ace4394 commit 4759d1f

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

ml-agents/mlagents/trainers/tests/torch/test_distributions.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -105,8 +105,8 @@ def test_gaussian_dist_instance():
105105
assert log_prob == pytest.approx(-0.919, abs=0.01)
106106

107107
for ent in dist_instance.entropy().flatten():
108-
# entropy of standard normal at 0
109-
assert ent == pytest.approx(2.83, abs=0.01)
108+
# entropy of standard normal at 0, based on 1/2 + ln(sqrt(2pi)sigma)
109+
assert ent == pytest.approx(1.42, abs=0.01)
110110

111111

112112
def test_tanh_gaussian_dist_instance():

ml-agents/mlagents/trainers/tests/torch/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -149,7 +149,7 @@ def test_get_probs_and_entropy():
149149

150150
for ent in entropies.flatten():
151151
# entropy of standard normal at 0
152-
assert ent == pytest.approx(2.83, abs=0.01)
152+
assert ent == pytest.approx(1.42, abs=0.01)
153153

154154
# Test continuous
155155
# Add two dists to the list.

ml-agents/mlagents/trainers/torch/distributions.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ def pdf(self, value):
6666
return torch.exp(log_prob)
6767

6868
def entropy(self):
69-
return torch.log(2 * math.pi * math.e * self.std + EPSILON)
69+
return 0.5 * torch.log(2 * math.pi * math.e * self.std + EPSILON)
7070

7171

7272
class TanhGaussianDistInstance(GaussianDistInstance):

0 commit comments

Comments
 (0)