PPO evaluation mode in train_jax_ppo.py not working as expected #58

Closed
@ChenDavidTimothy

Description

In the current implementation of our PPO training code (train_jax_ppo.py), there are two issues with the evaluation mode:

  1. The --play_only flag doesn't actually skip training. While it's intended to "only play with the model and do not train" (as per the flag description), it only affects whether wandb/tensorboard logging is initialized. The training process still runs regardless of this flag.

  2. When loading a checkpoint via --load_checkpoint_path, the checkpoint is used as a starting point for further training rather than for pure evaluation. This means there's no way to truly evaluate a trained model without additional training.

Current behavior:

# Flag definition
_PLAY_ONLY = flags.DEFINE_boolean(
    "play_only", False, "If true, only play with the model and do not train"
)

# Only affects logging
if _USE_WANDB.value and not _PLAY_ONLY.value:
    wandb.init(...)

# Training happens regardless
make_inference_fn, params, _ = train_fn(
    environment=env,
    progress_fn=progress,
    eval_env=eval_env,
)

Expected behavior:

  • When --play_only=True is set, the code should skip training entirely
  • When a checkpoint is loaded with --load_checkpoint_path, it should be used for direct evaluation without additional training
  • Both behaviors should still allow proper evaluation and video generation
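
One possible shape for this, as a rough sketch only: decide up front whether to train, and obtain `make_inference_fn` and `params` from either path. Everything here is hypothetical and not taken from train_jax_ppo.py or Brax: `resolve_policy`, `load_params`, `build_make_inference_fn`, and the `fake_*` stand-ins are illustrative names, and the real training function's signature is not reproduced. The sketch also assumes a checkpoint given without --play_only still resumes training, which is a design choice this issue leaves open.

```python
from typing import Any, Callable, Optional, Tuple

Params = Any
Policy = Callable[[Any], Any]
MakeInferenceFn = Callable[[Params], Policy]


def resolve_policy(
    play_only: bool,
    checkpoint_path: Optional[str],
    train: Callable[[Optional[Params]], Tuple[MakeInferenceFn, Params]],
    load_params: Callable[[str], Params],
    build_make_inference_fn: Callable[[], MakeInferenceFn],
) -> Tuple[MakeInferenceFn, Params, bool]:
    """Returns (make_inference_fn, params, trained).

    All callables are injected placeholders (assumptions), so the control
    flow can be shown without depending on any real training library.
    """
    # Restore parameters first, if a checkpoint path was given.
    restored = load_params(checkpoint_path) if checkpoint_path else None

    if play_only:
        # Evaluation-only: never call train(); a checkpoint is mandatory.
        if restored is None:
            raise ValueError("--play_only requires --load_checkpoint_path")
        # Build the inference factory directly instead of getting it from training.
        return build_make_inference_fn(), restored, False

    # Normal path: train, optionally starting from the restored parameters.
    make_inference_fn, params = train(restored)
    return make_inference_fn, params, True


# --- Toy stand-ins so the sketch runs on its own (not real training code) ---
calls = []


def fake_build() -> MakeInferenceFn:
    # "Policy" here just scales the observation by the parameter value.
    return lambda params: (lambda obs: obs * params)


def fake_train(initial_params: Optional[Params]):
    calls.append("train")
    start = 0 if initial_params is None else initial_params
    return fake_build(), start + 1


def fake_load(path: str) -> Params:
    calls.append("load")
    return 10


make_inference_fn, params, trained = resolve_policy(
    play_only=True,
    checkpoint_path="ckpt",  # placeholder path
    train=fake_train,
    load_params=fake_load,
    build_make_inference_fn=fake_build,
)
policy = make_inference_fn(params)
```

With `play_only=True`, `fake_train` is never invoked and the policy is built from the restored parameters alone, so evaluation and video generation could proceed from `policy` as before. In the real script, `build_make_inference_fn` would have to reconstruct the same network the checkpoint was trained with, which is the open part of question 2 below.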

Questions:

  1. What's the best way to restructure the code to properly implement an evaluation-only mode?
  2. How should we handle the creation of make_inference_fn when skipping training?

Any suggestions or guidance would be greatly appreciated!
