
[Bug] compute_entropy_from_logits uses softmax(log_probs) instead of exp(log_probs), producing incorrect entropy #1386

@kuishou68

Bug Report

File: tunix/rl/ppo/ppo_helpers.py
Function: compute_entropy_from_logits
Line: 164

Description

The compute_entropy_from_logits function converts log-probabilities back to probabilities by calling jax.nn.softmax(log_probs) on line 164. The intended operation is jnp.exp(log_probs): softmax treats its input as unnormalized logits and re-normalizes it, so it is not the inverse of log_softmax.

softmax(log_softmax(x)) is not the same operation as exp(log_softmax(x)); it only yields the same values because a log-probability vector already exponentiates to a normalized distribution.

Current code:

```python
def compute_entropy_from_logits(logits: jax.Array) -> jax.Array:
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    probs = jax.nn.softmax(log_probs)   # ← redundant: re-normalizes already-normalized log-probs
    return -jnp.sum(probs * log_probs, axis=-1)
```

Proposed code:

```python
def compute_entropy_from_logits(logits: jax.Array) -> jax.Array:
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    probs = jnp.exp(log_probs)          # ← exp inverts log_softmax directly
    return -jnp.sum(probs * log_probs, axis=-1)
```
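For context, the same entropy can also be computed straight from the logits without materializing probs, via the identity H(p) = logsumexp(x) − Σᵢ pᵢ·xᵢ where p = softmax(x). A minimal sketch of that identity, with NumPy standing in for JAX (entropy_from_logits is a hypothetical helper for illustration, not part of tunix):

```python
import numpy as np

def entropy_from_logits(x, axis=-1):
    # H(p) = logsumexp(x) - sum_i p_i * x_i, where p = softmax(x).
    # Shift by the max first for numerical stability.
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_shifted = np.exp(x - x_max)
    lse = np.squeeze(x_max, axis) + np.log(np.sum(exp_shifted, axis=axis))
    p = exp_shifted / np.sum(exp_shifted, axis=axis, keepdims=True)
    return lse - np.sum(p * x, axis=axis)

logits = np.array([2.0, 1.0, 0.1])
print(entropy_from_logits(logits))
```

This form avoids the log-prob/prob conversion entirely, which sidesteps the kind of confusion reported here.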

Why it matters

  • jax.nn.softmax treats its input as unnormalized logits and re-normalizes it, so softmax(log_probs) performs a second, redundant normalization pass (an extra max and sum reduction over the vocabulary axis) on every call.
  • exp(log_softmax(x)) = softmax(x) is the direct way to recover the probabilities; softmax(log_softmax(x)) coincides with it only because the log-probabilities already exponentiate to a distribution that sums to one.
  • Relying on that coincidence obscures the intent of the code, wastes compute on a hot path, and would silently mask a bug if the input ever stopped being exactly normalized log-probabilities.
  • This function is called in ppo_learner.py (line 654) to compute token_entropy, which feeds PPO training metrics, so the conversion should be both explicit and cheap.

Minimal reproducible example

```python
import jax
import jax.numpy as jnp

logits = jnp.array([2.0, 1.0, 0.1])

log_probs = jax.nn.log_softmax(logits)

# Current: softmax applied to log-probs (redundant re-normalization)
probs_wrong = jax.nn.softmax(log_probs)
# Proposed: exp of log-probs, the direct inverse of log_softmax
probs_correct = jnp.exp(log_probs)

print('probs_wrong  :', probs_wrong)
print('probs_correct:', probs_correct)
print('expected     :', jax.nn.softmax(logits))
```

The three printed vectors agree to within floating-point rounding, which confirms that the softmax call adds only a redundant normalization rather than a different result.

Fix

Replace jax.nn.softmax(log_probs) with jnp.exp(log_probs) on line 164.
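As a quick sanity check on the proposed form, here is a self-contained sketch of the corrected function, with NumPy standing in for JAX (log_softmax below is a hand-rolled stand-in for jax.nn.log_softmax, not the library function):

```python
import numpy as np

def log_softmax(x, axis=-1):
    # Numerically stable log-softmax: shift by the max before exponentiating.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

def compute_entropy_from_logits(logits):
    # Proposed fix: exponentiate the log-probabilities directly.
    log_probs = log_softmax(logits, axis=-1)
    probs = np.exp(log_probs)
    return -np.sum(probs * log_probs, axis=-1)

# Uniform logits over n classes give the maximum entropy log(n).
print(compute_entropy_from_logits(np.zeros(4)))   # ≈ log(4) ≈ 1.3863

# A sharply peaked distribution has entropy near 0.
print(compute_entropy_from_logits(np.array([100.0, 0.0, 0.0])))
```

The uniform and peaked cases bracket the expected range, so they make a cheap regression test for the fixed line.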

Metadata

Labels

type:bug (Something isn't working)