Description
In the current implementation of our PPO training code (train_jax_ppo.py), there are two issues with the evaluation mode:
- The `--play_only` flag doesn't actually skip training. While it's intended to "only play with the model and do not train" (as the flag description says), it only affects whether wandb/tensorboard logging is initialized; the training process still runs regardless of this flag.
- When loading a checkpoint via `--load_checkpoint_path`, the checkpoint is used as a starting point for further training rather than for pure evaluation, so there is currently no way to evaluate a trained model without additional training.
Current behavior:

```python
# Flag definition
_PLAY_ONLY = flags.DEFINE_boolean(
    "play_only", False, "If true, only play with the model and do not train"
)

# Only affects logging
if _USE_WANDB.value and not _PLAY_ONLY.value:
    wandb.init(...)

# Training happens regardless
make_inference_fn, params, _ = train_fn(
    environment=env,
    progress_fn=progress,
    eval_env=eval_env,
)
```
Expected behavior:
- When `--play_only=True` is set, the code should skip training entirely (see the sketch below).
- When a checkpoint is loaded with `--load_checkpoint_path`, it should be used for direct evaluation without additional training.
- Both behaviors should still allow proper evaluation and video generation.
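For illustration, here is a minimal sketch of how the top-level flow could branch on the existing flag. The `_LOAD_CHECKPOINT_PATH` constant is assumed to correspond to the `--load_checkpoint_path` flag already in the script, and `load_policy` is a hypothetical helper (one way to implement it is sketched under Questions below); the rest reuses the script's existing setup.

```python
# Sketch only: branch on --play_only instead of always calling train_fn.
if _PLAY_ONLY.value:
    # Evaluation-only path: restore params from the checkpoint and build an
    # inference function without running any training steps.
    assert _LOAD_CHECKPOINT_PATH.value, "--play_only requires --load_checkpoint_path"
    make_inference_fn, params = load_policy(_LOAD_CHECKPOINT_PATH.value, env)  # hypothetical helper
else:
    # Training path, unchanged from the current script.
    make_inference_fn, params, _ = train_fn(
        environment=env,
        progress_fn=progress,
        eval_env=eval_env,
    )

# Downstream evaluation / video generation works the same in both modes,
# assuming the Brax-style make_inference_fn(params, deterministic=...) interface
# the script already relies on.
jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))
```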
Questions:
- What's the best way to restructure the code to properly implement an evaluation-only mode?
- How should we handle the creation of `make_inference_fn` when skipping training? (One possible approach is sketched below.)
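On the second question: if train_jax_ppo.py uses Brax's PPO implementation (which the `train_fn(...)` return signature suggests), one option is to rebuild the networks directly from `brax.training.agents.ppo.networks` and derive `make_inference_fn` from them without ever calling `train_fn`. The sketch below is an assumption-laden illustration, not the script's real code: the network hyperparameters must match the ones used at training time, and the Orbax restore call stands in for however the script actually saves checkpoints.

```python
# Sketch only: construct make_inference_fn without training, assuming Brax PPO
# networks and an Orbax-compatible checkpoint layout.
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as ppo_networks
from orbax import checkpoint as ocp


def load_policy(checkpoint_path, env):
  """Hypothetical helper: rebuild the PPO policy and restore its params."""
  # Must mirror the network configuration used during training
  # (hidden layer sizes, normalization, etc.); defaults shown for brevity.
  ppo_network = ppo_networks.make_ppo_networks(
      observation_size=env.observation_size,
      action_size=env.action_size,
      preprocess_observations_fn=running_statistics.normalize,
  )
  make_inference_fn = ppo_networks.make_inference_fn(ppo_network)

  # Restore the saved (normalizer_params, policy_params) tree. The exact
  # structure depends on how the training run wrote the checkpoint, so treat
  # this restore call as illustrative rather than the script's real logic.
  params = ocp.PyTreeCheckpointer().restore(checkpoint_path)
  return make_inference_fn, params
```

An alternative worth checking is whether the wrapped Brax `train_fn` already accepts something like `num_timesteps=0` together with a checkpoint restore path, which would reuse its own network construction instead of duplicating it in a helper like the one above.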
Any suggestions or guidance would be greatly appreciated!