
@bsyh bsyh commented Aug 30, 2025

Note to maintainers: This is a draft PR related to #630. Feedback on the implementation is very welcome.

This PR adds an interface for real-time rendering in PPO.train and in acting: a user-defined callback is executed via jax.experimental.io_callback. The end goal is to provide a real-time viewer alongside the notebook HTML viewer.

Key Changes

Two new optional parameters are added:

  • render_fn: A Python callable that accepts a brax.State to handle the rendering logic.
  • should_render: A boolean JAX Array used to conditionally trigger the callback.
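A minimal sketch of how these two parameters could fit together, assuming a toy step function and a scalar array standing in for brax.State. The names `make_step` and `record_frame` are hypothetical and are not taken from this PR:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

frames = []  # host-side list filled by the callback


def record_frame(state):
    # Hypothetical render_fn: runs on the host with the concrete state value.
    frames.append(float(np.asarray(state)))


def make_step(render_fn):
    def step(state, should_render):
        nstate = state + 1.0  # stand-in for an environment step

        def render(s):
            # result_shape_dtypes=None: the callback returns nothing.
            return io_callback(render_fn, None, s)

        # should_render is a traced boolean, so the branch is picked at run time.
        jax.lax.cond(should_render, render, lambda s: None, nstate)
        return nstate

    return jax.jit(step)


step = make_step(record_frame)
s = step(jnp.float32(0.0), jnp.array(True, dtype=bool))   # callback fires
s = step(s, jnp.array(False, dtype=bool))                 # callback skipped
jax.effects_barrier()  # wait for any pending host callbacks
```

Because should_render is an array rather than a Python bool, the same compiled function serves both the rendering and non-rendering case.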

Performance Impact

Adding the new callback introduces a minor performance overhead. This overhead exists even when rendering is disabled (i.e., should_render is False), because the JIT compiler must still account for the conditional logic required for the io_callback, which slightly alters the compiled execution path.

The benchmarks below were run on an Apple M1 Max and an NVIDIA 2080.

[Figure: ppo_comparison_consolidated]


google-cla bot commented Aug 30, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@bsyh bsyh marked this pull request as ready for review September 2, 2025 16:25
Collaborator

@btaba btaba left a comment


Thanks for the contribution! Added a bunch of comments

return
io_callback(render_fn, None, state)

jax.lax.cond(should_render, render, lambda s: None, nstate)
Collaborator

can this whole block just be

if render_fn:
  io_callback(render_fn, None, state)

Collaborator

This may get rid of the fixed overhead described in your main post: when render_fn is None, JAX should skip this whole block at trace time.
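To illustrate the suggestion above, here is a small sketch, assuming a toy step function rather than the PR's actual code (`make_step` and `record` are hypothetical names). A plain Python `if` on render_fn is evaluated while tracing, so when render_fn is None no callback should appear in the traced program at all:

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback


def make_step(render_fn=None):
    def step(x):
        nx = x + 1.0  # stand-in for an environment step
        if render_fn is not None:  # Python-level check, resolved at trace time
            io_callback(render_fn, None, nx)
        return nx

    return step


def record(x):
    # Hypothetical render_fn; a real one would draw the state.
    pass


# Inspect the traced programs for both variants.
jaxpr_plain = str(jax.make_jaxpr(make_step())(jnp.float32(0.0)))
jaxpr_render = str(jax.make_jaxpr(make_step(record))(jnp.float32(0.0)))

has_callback_plain = "io_callback" in jaxpr_plain
has_callback_render = "io_callback" in jaxpr_render
```

The trade-off is that the decision is fixed per compiled function: toggling rendering at run time would require a second compiled variant or a traced flag.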

eval_policy_fn(policy_params),
key,
unroll_length=episode_length // action_repeat,
should_render=jnp.array(False, dtype=jnp.bool_), # No rendering during eval
Collaborator

You're not passing render_fn here anyway, so I'm not sure you really need should_render.

policy_params_fn: Callable[..., None] = lambda *args: None,
# rendering
render_fn: Optional[Callable[[envs.State], None]] = None,
should_render: jax.Array = jnp.array(True, dtype=jnp.bool_),
Collaborator

dtype=bool
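A quick check of why the shorter spelling is presumably safe here: as far as I know, the Python builtin `bool` and `jnp.bool_` resolve to the same array dtype in JAX. The variable names below are illustrative only:

```python
import jax.numpy as jnp

verbose = jnp.array(True, dtype=jnp.bool_)  # spelling used in the diff
short = jnp.array(True, dtype=bool)         # spelling suggested in review

same_dtype = verbose.dtype == short.dtype
```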

Returns:
Tuple of (make_policy function, network params, metrics)
"""
# If the environment is wrapped with ViewerWrapper, use its rendering functions.
Collaborator

Not sure what ViewerWrapper is; maybe update the comment?

unroll_length,
extra_fields=('truncation', 'episode_metrics', 'episode_done'),
render_fn=render_fn,
should_render=should_render,
Collaborator

Again, you can probably get away without the should_render bool.

# optimization

# check for rendering dynamically
should_render_py = False
Collaborator

So you're ignoring the arg to train(); just delete all the args.
