
Commit 67c26e4

afrozenator authored and copybara-github committed
Probs, while sampling actions from the policy, still aren't normalized sometimes. They are very close, but the numpy check fails, so normalize them.
(Maybe do all that in JAX?)

PiperOrigin-RevId: 246932988
1 parent ccbe132 commit 67c26e4

File tree

1 file changed: +8 −2 lines

tensor2tensor/envs/env_problem_utils.py

Lines changed: 8 additions & 2 deletions
@@ -106,20 +106,26 @@ def multinomial_sample(probs):
   log_probs = log_prob_actions[np.arange(B)[:, None],
                                index[:, None],
                                np.arange(A)]
-  assert (B, A) == log_probs.shape
+  assert (B, A) == log_probs.shape, \
+      "B=%d, A=%d, log_probs.shape=%s" % (B, A, log_probs.shape)
 
   # Convert to probs, since we need to do categorical sampling.
   probs = np.exp(log_probs)
 
   # Sometimes log_probs contains a 0, it shouldn't. This makes the
   # probabilities sum up to more than 1, since the addition happens
   # in float64, so just add and subtract 1.0 to zero those probabilites
-  # out. Real example encountered probs = [1e-8, 1.0, 1e-22]
+  # out.
   #
   # Also testing for this is brittle.
   probs += 1
   probs -= 1
 
+  # For some reason, sometimes, this isn't the case.
+  probs_sum = np.sum(probs, axis=1, keepdims=True)
+  if not all(probs_sum == 1.0):
+    probs = probs / probs_sum
+
   # Now pick actions from this probs array.
   actions = np.apply_along_axis(multinomial_sample, 1, probs)
 
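For context, a minimal standalone sketch of the sampling path this change patches, assuming per-action log-probs of shape (B, A); the values B, A and log_probs are mocked stand-ins here for what env_problem_utils.py computes earlier in the real function, and the numpy check the commit message mentions is np.random.multinomial rejecting probability vectors whose sum drifts above 1.0:

import numpy as np


def multinomial_sample(probs):
  # Draw one categorical sample from a 1-D probability vector; the argmax
  # of the one-hot count vector is the sampled action index.
  return np.argmax(np.random.multinomial(1, probs))


# Mocked stand-ins for the batch of per-action log-probabilities.
B, A = 2, 3
log_probs = np.log(np.array([[0.7, 0.2, 0.1],
                             [0.1, 0.1, 0.8]]))
assert (B, A) == log_probs.shape

# Convert to probs, since we need to do categorical sampling.
probs = np.exp(log_probs)

# exp() can leave each row summing to slightly more or less than 1.0 in
# float64, which np.random.multinomial rejects, so renormalize per row.
probs_sum = np.sum(probs, axis=1, keepdims=True)
if not np.all(probs_sum == 1.0):
  probs = probs / probs_sum

actions = np.apply_along_axis(multinomial_sample, 1, probs)
print(actions.shape)  # (2,): one sampled action index per batch row.

The per-row renormalization is the substance of the commit: even after the add-and-subtract-1.0 trick, the row sums can still fail numpy's exact check, so dividing by the row sum makes the multinomial draw safe.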