Commit d58b986

[rllib] MultiCategorical shouldn't return array for kl or entropy (ray-project#5215)
* wip
* fix
1 parent da7676c commit d58b986
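
For context (not stated in the commit itself): MultiCategorical backs MultiDiscrete action spaces with one Categorical per sub-space. Before this change, entropy() and kl() returned one value per sub-space; after it, they reduce over sub-spaces and return one value per sample, while the new multi_entropy()/multi_kl() keep the per-component form. A minimal sketch of that shape contract, using NumPy stand-ins rather than the real distribution classes:

import numpy as np

# Toy per-component entropies for a batch of 4 samples and 3 sub-spaces,
# standing in for what multi_entropy() yields on a MultiDiscrete space.
per_component = np.random.rand(4, 3)    # multi_entropy(): shape [batch, num_sub_spaces]
per_sample = per_component.sum(axis=1)  # entropy(): shape [batch]

assert per_component.shape == (4, 3)
assert per_sample.shape == (4,)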

File tree: 4 files changed, +36 −11 lines

ci/travis/format.sh

Lines changed: 0 additions & 4 deletions
@@ -29,10 +29,6 @@ YAPF_VERSION=$(yapf --version | awk '{print $2}')
 tool_version_check() {
     if [[ $2 != $3 ]]; then
         echo "WARNING: Ray uses $1 $3, You currently are using $2. This might generate different results."
-        read -p "Do you want to continue? [y/n] " answer
-        if ! [ $answer = 'y' ] && ! [ $answer = 'Y' ]; then
-            exit 1
-        fi
     fi
 }
 

python/ray/rllib/agents/impala/vtrace_policy.py

Lines changed: 2 additions & 1 deletion
@@ -184,7 +184,8 @@ def make_time_major(*args, **kw):
         actions=make_time_major(loss_actions, drop_last=True),
         actions_logp=make_time_major(
             action_dist.logp(actions), drop_last=True),
-        actions_entropy=make_time_major(action_dist.entropy(), drop_last=True),
+        actions_entropy=make_time_major(
+            action_dist.multi_entropy(), drop_last=True),
         dones=make_time_major(dones, drop_last=True),
         behaviour_logits=make_time_major(
             unpacked_behaviour_logits, drop_last=True),

python/ray/rllib/agents/ppo/appo_policy.py

Lines changed: 4 additions & 4 deletions
@@ -205,9 +205,9 @@ def make_time_major(*args, **kw):
             prev_action_dist.logp(actions), drop_last=True),
         actions_logp=make_time_major(
             action_dist.logp(actions), drop_last=True),
-        action_kl=prev_action_dist.kl(action_dist),
+        action_kl=prev_action_dist.multi_kl(action_dist),
         actions_entropy=make_time_major(
-            action_dist.entropy(), drop_last=True),
+            action_dist.multi_entropy(), drop_last=True),
         dones=make_time_major(dones, drop_last=True),
         behaviour_logits=make_time_major(
             unpacked_behaviour_logits, drop_last=True),
@@ -229,8 +229,8 @@ def make_time_major(*args, **kw):
     policy.loss = PPOSurrogateLoss(
         prev_actions_logp=make_time_major(prev_action_dist.logp(actions)),
         actions_logp=make_time_major(action_dist.logp(actions)),
-        action_kl=prev_action_dist.kl(action_dist),
-        actions_entropy=make_time_major(action_dist.entropy()),
+        action_kl=prev_action_dist.multi_kl(action_dist),
+        actions_entropy=make_time_major(action_dist.multi_entropy()),
         values=make_time_major(values),
         valid_mask=make_time_major(mask),
         advantages=make_time_major(
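
The APPO call sites switch to multi_kl()/multi_entropy() so the loss keeps receiving the per-component form it already handled; how the loss then folds that into a scalar is not shown in this diff. A hedged, standalone sketch of one plausible reduction (the names and the averaging are assumptions, not RLlib code):

import tensorflow as tf

# multi_kl() on a MultiCategorical returns one KL tensor per sub-space,
# each of shape [batch]; here two hypothetical components for a batch of 4.
kl_components = [tf.constant([0.10, 0.20, 0.30, 0.40]),
                 tf.constant([0.00, 0.10, 0.00, 0.20])]

kl_per_sample = kl_components[0] + kl_components[1]  # sum over sub-spaces -> shape [4]
kl_penalty = tf.reduce_mean(kl_per_sample)           # scalar term a loss could use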

python/ray/rllib/models/action_dist.py

Lines changed: 30 additions & 2 deletions
@@ -69,6 +69,22 @@ def sampled_action_prob(self):
         """Returns the log probability of the sampled action."""
         return tf.exp(self.logp(self.sample_op))
 
+    def multi_kl(self, other):
+        """The KL-divergence between two action distributions.
+
+        This differs from kl() in that it can return an array for
+        MultiDiscrete. TODO(ekl) consider removing this.
+        """
+        return self.kl(other)
+
+    def multi_entropy(self):
+        """The entropy of the action distribution.
+
+        This differs from entropy() in that it can return an array for
+        MultiDiscrete. TODO(ekl) consider removing this.
+        """
+        return self.entropy()
+
 
 class Categorical(ActionDistribution):
     """Categorical distribution for discrete action spaces."""
@@ -133,6 +149,7 @@ def __init__(self, inputs, input_lens):
         ]
         self.sample_op = self._build_sample_op()
 
+    @override(ActionDistribution)
     def logp(self, actions):
         # If tensor is provided, unstack it into list
         if isinstance(actions, tf.Tensor):
@@ -141,12 +158,23 @@ def logp(self, actions):
             [cat.logp(act) for cat, act in zip(self.cats, actions)])
         return tf.reduce_sum(logps, axis=0)
 
-    def entropy(self):
+    @override(ActionDistribution)
+    def multi_entropy(self):
         return tf.stack([cat.entropy() for cat in self.cats], axis=1)
 
-    def kl(self, other):
+    @override(ActionDistribution)
+    def entropy(self):
+        return tf.reduce_sum(self.multi_entropy(), axis=1)
+
+    @override(ActionDistribution)
+    def multi_kl(self, other):
         return [cat.kl(oth_cat) for cat, oth_cat in zip(self.cats, other.cats)]
 
+    @override(ActionDistribution)
+    def kl(self, other):
+        return tf.reduce_sum(self.multi_kl(other), axis=1)
+
+    @override(ActionDistribution)
     def _build_sample_op(self):
         return tf.stack([cat.sample() for cat in self.cats], axis=1)
 
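The net effect for MultiCategorical: multi_entropy() stacks the per-sub-space entropies, and entropy() now reduces that stack so callers get one value per sample (kl()/multi_kl() follow the same pattern). A small standalone sketch of that composition with toy tensors (not the RLlib classes themselves):

import tensorflow as tf

# Stand-ins for [cat.entropy() for cat in self.cats]: two sub-spaces,
# batch of 3 samples, each tensor of shape [batch].
component_entropies = [tf.constant([1.0, 0.5, 0.2]),
                       tf.constant([0.7, 0.3, 0.1])]

multi_entropy = tf.stack(component_entropies, axis=1)  # shape [3, 2], as multi_entropy()
entropy = tf.reduce_sum(multi_entropy, axis=1)         # shape [3], as entropy()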
