Bug Report
File: `tunix/rl/ppo/ppo_helpers.py`
Function: `compute_entropy_from_logits`
Line: 164
Description
On line 164, `compute_entropy_from_logits` converts log-probabilities back to probabilities with `jax.nn.softmax(log_probs)`. This is redundant rather than numerically wrong: `log_softmax(x) = x - logsumexp(x)` shifts `x` by a per-row scalar, and `softmax` is invariant to scalar shifts, so `softmax(log_softmax(x))` equals `exp(log_softmax(x)) = softmax(x)` up to floating-point rounding. The intended operation is a plain element-wise `jnp.exp(log_probs)`; calling `softmax` instead re-runs a full normalization (an extra `logsumexp` reduction over the vocabulary axis) and misstates what the code is doing.
Current (redundant) code:

```python
def compute_entropy_from_logits(logits: jax.Array) -> jax.Array:
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  probs = jax.nn.softmax(log_probs)  # redundant: re-normalizes already-normalized log-probs
  return -jnp.sum(probs * log_probs, axis=-1)
```
Expected (correct) code:

```python
def compute_entropy_from_logits(logits: jax.Array) -> jax.Array:
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  probs = jnp.exp(log_probs)  # element-wise exp recovers the probabilities directly
  return -jnp.sum(probs * log_probs, axis=-1)
```
Why it matters
- `softmax(log_probs)` re-normalizes a distribution that is already normalized: `logsumexp(log_probs) = 0`, so the call returns the same probabilities as `jnp.exp(log_probs)` up to floating-point error, while paying for an extra reduction over the vocabulary axis.
- `exp(log_softmax(x)) = softmax(x)` expresses the intent directly; routing through `softmax` obscures the invariant that `log_probs` already sums to one in probability space.
- This function is called in `ppo_learner.py` line 654 to compute `token_entropy`, so it sits on the per-token hot path of PPO training, where the redundant normalization is pure overhead.
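A direct numerical check of the two conversion paths (a standalone sketch, not code from the repo; the shapes and random logits are arbitrary):

```python
import jax
import jax.numpy as jnp

def entropy_via_softmax(logits):
    # current code path: softmax applied to log-probabilities
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    probs = jax.nn.softmax(log_probs, axis=-1)
    return -jnp.sum(probs * log_probs, axis=-1)

def entropy_via_exp(logits):
    # proposed code path: element-wise exp of log-probabilities
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    probs = jnp.exp(log_probs)
    return -jnp.sum(probs * log_probs, axis=-1)

key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (4, 8)) * 5.0  # batch of 4 rows, "vocab" of 8

h_softmax = entropy_via_softmax(logits)
h_exp = entropy_via_exp(logits)
print(jnp.max(jnp.abs(h_softmax - h_exp)))  # tiny: float32 rounding only
```

The printed maximum difference is at the level of float32 rounding, confirming that the two paths compute the same entropies; the case for the change is clarity and cost, not correctness of the values.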
Minimal reproducible example

```python
import jax
import jax.numpy as jnp

logits = jnp.array([2.0, 1.0, 0.1])
log_probs = jax.nn.log_softmax(logits)

# Current: softmax of log-probs (redundant re-normalization)
probs_softmax = jax.nn.softmax(log_probs)

# Proposed: element-wise exp of log-probs
probs_exp = jnp.exp(log_probs)

print('probs_softmax:', probs_softmax)  # matches probs_exp up to float error
print('probs_exp    :', probs_exp)      # equals jax.nn.softmax(logits)
print('expected     :', jax.nn.softmax(logits))
```
Fix
Replace `jax.nn.softmax(log_probs)` with `jnp.exp(log_probs)` on line 164.
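After patching, a small regression test can pin the behavior against the textbook definition H = -Σ p·log p with p = softmax(logits). This is a sketch; the test name and placement are hypothetical, and the function body below is the proposed patched version, reproduced inline so the test is self-contained:

```python
import jax
import jax.numpy as jnp

def compute_entropy_from_logits(logits: jax.Array) -> jax.Array:
    # proposed patched version under test
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    probs = jnp.exp(log_probs)
    return -jnp.sum(probs * log_probs, axis=-1)

def test_entropy_matches_definition():
    logits = jax.random.normal(jax.random.PRNGKey(42), (3, 16))
    probs = jax.nn.softmax(logits, axis=-1)
    reference = -jnp.sum(probs * jnp.log(probs), axis=-1)
    assert jnp.allclose(compute_entropy_from_logits(logits), reference, atol=1e-5)
    # uniform logits give the maximum entropy, log(vocab_size)
    uniform = compute_entropy_from_logits(jnp.zeros((16,)))
    assert jnp.allclose(uniform, jnp.log(16.0), atol=1e-5)

test_entropy_matches_definition()
```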