
fix: use jnp.exp(log_probs) instead of softmax(log_probs) in compute_entropy_from_logits#1387

Open
kuishou68 wants to merge 1 commit into google:main from kuishou68:fix/issue-1386-entropy-softmax-vs-exp

Conversation

@kuishou68

Summary

Fixes a bug in compute_entropy_from_logits in tunix/rl/ppo/ppo_helpers.py.

Closes #1386

Problem

Line 164 applies jax.nn.softmax(log_probs) to convert log-probabilities to probabilities. This is mathematically incorrect:

  • softmax(log_softmax(x)) ≠ softmax(x)
  • Only exp(log_softmax(x)) = softmax(x) (i.e., the true probability distribution)

Applying softmax to already log-normalized values produces a different distribution, leading to silently incorrect entropy values during PPO training.
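The identity in the second bullet above can be checked numerically. A standalone sketch (illustrative only, not code from the repository):

```python
import jax
import jax.numpy as jnp

logits = jnp.array([2.0, 1.0, 0.1])
log_probs = jax.nn.log_softmax(logits)

# exp undoes the log: exp(log_softmax(x)) recovers the softmax distribution.
probs = jnp.exp(log_probs)
print(jnp.allclose(probs, jax.nn.softmax(logits)))  # True

# The result is a valid probability distribution (sums to 1).
print(abs(float(jnp.sum(probs)) - 1.0) < 1e-5)  # True
```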

Fix

Replace jax.nn.softmax(log_probs) with jnp.exp(log_probs) on line 164:

# Before (wrong):
probs = jax.nn.softmax(log_probs)

# After (correct):
probs = jnp.exp(log_probs)
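For context, a minimal sketch of an entropy computation using the corrected conversion. The function name and signature here are hypothetical (assuming the helper takes raw logits); this is not the actual tunix/rl/ppo/ppo_helpers.py code:

```python
import jax
import jax.numpy as jnp

def entropy_from_logits(logits):
    # Hypothetical sketch of the corrected pattern, not the tunix implementation.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    probs = jnp.exp(log_probs)  # the fix: exp, not another softmax
    return -jnp.sum(probs * log_probs, axis=-1)

logits = jnp.array([[2.0, 1.0, 0.1],
                    [0.0, 0.0, 0.0]])  # second row: uniform distribution
ent = entropy_from_logits(logits)

# Uniform over 3 classes has maximal entropy log(3); a peaked row has less.
print(jnp.allclose(ent[1], jnp.log(3.0)))  # True
print(bool(ent[0] < ent[1]))               # True
```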

Impact

The compute_entropy_from_logits function is called in ppo_learner.py to compute token_entropy. The bug causes incorrect entropy values, which can silently affect PPO training metrics and dynamics.



Development

Successfully merging this pull request may close these issues.

[Bug] compute_entropy_from_logits uses softmax(log_probs) instead of exp(log_probs), producing incorrect entropy
