Skip to content

Commit 200ab7b

Browse files
author
Ervin T
authored
[bug-fix] Fix entropy computation in MultiCategoricalDistribution (#3607)
1 parent cad085f commit 200ab7b

File tree

2 files changed

+11
-13
lines changed

2 files changed

+11
-13
lines changed

ml-agents/mlagents/trainers/distributions.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -189,13 +189,13 @@ def __init__(self, logits: tf.Tensor, act_size: List[int], action_masks: tf.Tens
189189
and 1 for unmasked.
190190
"""
191191
unmasked_log_probs = self._create_policy_branches(logits, act_size)
192-
self._sampled_policy, self._all_probs, action_index = self._get_masked_actions_probs(
193-
unmasked_log_probs, act_size, action_masks
194-
)
192+
(
193+
self._sampled_policy,
194+
self._all_probs,
195+
action_index,
196+
) = self._get_masked_actions_probs(unmasked_log_probs, act_size, action_masks)
195197
self._sampled_onehot = self._action_onehot(self._sampled_policy, act_size)
196-
self._entropy = self._create_entropy(
197-
self._sampled_onehot, self._all_probs, action_index, act_size
198-
)
198+
self._entropy = self._create_entropy(self._all_probs, action_index, act_size)
199199
self._total_prob = self._get_log_probs(
200200
self._sampled_onehot, self._all_probs, action_index, act_size
201201
)
@@ -263,11 +263,7 @@ def _get_log_probs(
263263
return log_probs
264264

265265
def _create_entropy(
266-
self,
267-
all_log_probs: tf.Tensor,
268-
sample_onehot: tf.Tensor,
269-
action_idx: List[int],
270-
act_size: List[int],
266+
self, all_log_probs: tf.Tensor, action_idx: List[int], act_size: List[int]
271267
) -> tf.Tensor:
272268
entropy = tf.reduce_sum(
273269
(

ml-agents/mlagents/trainers/tests/test_distributions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ def test_multicategorical_distribution():
113113
sess.run(init)
114114
output = sess.run(distribution.sample)
115115
for _ in range(10):
116-
sample, log_probs = sess.run(
117-
[distribution.sample, distribution.log_probs]
116+
sample, log_probs, entropy = sess.run(
117+
[distribution.sample, distribution.log_probs, distribution.entropy]
118118
)
119119
assert len(log_probs[0]) == sum(DISCRETE_ACTION_SPACE)
120120
# Assert action never exceeds [-1,1]
@@ -123,6 +123,8 @@ def test_multicategorical_distribution():
123123
assert act >= 0 and act <= DISCRETE_ACTION_SPACE[i]
124124
output = sess.run([distribution.total_log_probs])
125125
assert output[0].shape[0] == 1
126+
# Make sure entropy is correct
127+
assert entropy[0] > 3.8
126128

127129
# Test masks
128130
mask = []

0 commit comments

Comments
 (0)