-
Notifications
You must be signed in to change notification settings - Fork 312
Add real-time rendering callback #634
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Modified for Viewer
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution! Added a bunch of comments
return | ||
io_callback(render_fn, None, state) | ||
|
||
jax.lax.cond(should_render, render, lambda s: None, nstate) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can this whole block just be
if render_fn:
io_callback(render_fn, None, state)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This may get rid of the fixed overhead in your main post, JAX should be ignoring this whole block
eval_policy_fn(policy_params), | ||
key, | ||
unroll_length=episode_length // action_repeat, | ||
should_render=jnp.array(False, dtype=jnp.bool_), # No rendering during eval |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you're not passing render_fn
anyways, not sure you really need should_render
policy_params_fn: Callable[..., None] = lambda *args: None, | ||
# rendering | ||
render_fn: Optional[Callable[[envs.State], None]] = None, | ||
should_render: jax.Array = jnp.array(True, dtype=jnp.bool_), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dtype=bool
Returns: | ||
Tuple of (make_policy function, network params, metrics) | ||
""" | ||
# If the environment is wrapped with ViewerWrapper, use its rendering functions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure what ViewerWrapper is, maybe update the comment?
unroll_length, | ||
extra_fields=('truncation', 'episode_metrics', 'episode_done'), | ||
render_fn=render_fn, | ||
should_render=should_render, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
again, you maybe can get away without the bool should_render
# optimization | ||
|
||
# check for rendering dynamically | ||
should_render_py = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so you're ignoring the arg to train()
, just delete all the args
This PR adds an interface for real-time rendering in
PPO.train
andacting
by adding a user-defined callback executed viajax.experimental.io_callback
. The final goal is to provide a real-time viewer beside the notebook html viewer.Key Changes
Two new optional parameters are added:
render_fn
: A Python callable that accepts abrax.State
to handle the rendering logic.should_render
: A boolean JAXArray
used to conditionally trigger the callback.Performance Impact
Adding the new callback introduces a minor performance. This overhead exists even when rendering is disabled (i.e.,
should_render
isFalse
). The JIT compiler must account for the conditional logic required for theio_callback
, which slightly alters the compiled execution path.The benchmarks below, run on an Apple M1 Max and NVIDIA 2080.