A robust, fast PPO (Proximal Policy Optimization) implementation in Rust using the Burn ML library. Designed for discrete action spaces and board games.
- Fast GPU training via Burn's WGPU backend (Metal/Vulkan/DirectX auto-detection)
- TOML configuration with CLI overrides for experiments
- Checkpointing with `best` and `latest` symlinks for easy access
- JSON-lines metrics with optional Aim streaming for visualization
- Vectorized environments for parallel rollout collection
- Two test environments: CartPole and Connect Four
```sh
# Build in release mode (much faster)
cargo build --release

# Train on CartPole (default)
cargo run --release

# Train with custom config
cargo run --release -- --config configs/default.toml --seed 123

# Override specific parameters
cargo run --release -- --learning-rate 0.0003 --num-envs 64
```

By default, burn-ppo uses WGPU, which auto-detects Metal (macOS), Vulkan (Linux/Windows), or DirectX. Alternative backends are available via feature flags:
```sh
# Default: WGPU (Metal/Vulkan/DirectX auto-detection)
cargo build --release

# CUDA backend (requires CUDA toolkit)
cargo build --release --features cuda

# LibTorch backend (requires libtorch)
cargo build --release --features libtorch
```

Backend priority: `cuda` > `libtorch` > `wgpu` (default).
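The backend is selected at compile time via feature flags. A minimal sketch of how that gating might look (the exact Burn type paths are assumptions and vary by Burn version; the real wiring lives in main.rs):

```rust
// Sketch only: feature-gated backend aliases mirroring the priority
// cuda > libtorch > wgpu. Type paths are assumptions about Burn's API.
#[cfg(feature = "cuda")]
type Backend = burn::backend::Autodiff<burn::backend::Cuda>;

#[cfg(all(feature = "libtorch", not(feature = "cuda")))]
type Backend = burn::backend::Autodiff<burn::backend::LibTorch>;

#[cfg(not(any(feature = "cuda", feature = "libtorch")))]
type Backend = burn::backend::Autodiff<burn::backend::Wgpu>;
```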
Default configuration in `configs/default.toml`:

```toml
env = "cartpole"
num_envs = "auto"     # Scales to 2x CPU cores
num_steps = 128
learning_rate = 2.5e-4
gamma = 0.99
gae_lambda = 0.95
clip_epsilon = 0.2
entropy_coef = 0.01
total_timesteps = 1_000_000
```

CLI overrides use kebab-case: `--learning-rate`, `--num-envs`, `--total-timesteps`.
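A rough sketch of how such a config might be deserialized and merged with a CLI override (field names follow configs/default.toml; the `NumEnvs` type and `load` function are illustrative, not the real config.rs):

```rust
// Illustrative config loading with serde + toml; the real config.rs may differ.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum NumEnvs {
    Fixed(usize),
    Auto(String), // "auto" -> resolved to 2x CPU cores at startup
}

#[derive(Debug, Deserialize)]
struct Config {
    env: String,
    num_envs: NumEnvs,
    num_steps: usize,
    learning_rate: f64,
    gamma: f64,
    gae_lambda: f64,
    clip_epsilon: f64,
    entropy_coef: f64,
    total_timesteps: u64,
}

fn load(path: &std::path::Path, lr_override: Option<f64>) -> anyhow::Result<Config> {
    let mut cfg: Config = toml::from_str(&std::fs::read_to_string(path)?)?;
    if let Some(lr) = lr_override {
        cfg.learning_rate = lr; // kebab-case flag --learning-rate wins over TOML
    }
    Ok(cfg)
}
```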
Training logs to `runs/<run_name>/metrics.jsonl`. To visualize with Aim:

```sh
cd scripts
uv sync                                  # Install Python dependencies
uv run aim init                          # Initialize Aim repo (once)
uv run aim up                            # Start Aim UI (http://localhost:43800)

# In another terminal
uv run aim_watcher.py ../runs/<run_name> # Stream metrics
```

The watcher tracks file offsets, so you can restart it without duplicate logs.
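The JSON-lines format is what makes offset-based tailing cheap: one self-contained JSON object per line, append-only. A sketch of that pattern (the field names are illustrative; the real record schema is defined in metrics.rs):

```rust
use std::io::Write;

// Append one metrics record per line so a tail-style watcher can resume
// from a saved byte offset without re-reading earlier records.
fn log_metrics(file: &mut std::fs::File, step: u64, mean_return: f64) -> std::io::Result<()> {
    writeln!(file, r#"{{"step":{},"mean_return":{}}}"#, step, mean_return)
}
```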
For detailed performance analysis, build with Tracy instrumentation:
```sh
# Build with Tracy profiling
cargo build --release --features tracy

# Run your training - Tracy will auto-connect
cargo run --release --features tracy -- --config configs/cartpole.toml
```

Then use the Tracy profiler GUI to connect and analyze:
- Frame timing for each training update
- Function-level timing for rollouts, GAE, PPO updates
- GPU/CPU data transfer costs
- Neural network forward/backward pass breakdown
Note: Building with Tracy requires a C++ compiler (Xcode Command Line Tools on macOS).
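Instrumentation points are ordinary code spans. As a hypothetical example, assuming the `tracy` feature is backed by the tracy-client crate (this README does not name the crate, and the actual span names live in the source):

```rust
// Hypothetical: zone markers as they might appear in the training loop.
fn main() {
    // Start the Tracy client once; spans are no-ops until a profiler attaches.
    let _client = tracy_client::Client::start();
    for _update in 0..10 {
        let _zone = tracy_client::span!("ppo_update"); // one timeline zone per update
        // ... collect rollouts, compute GAE, run minibatch updates ...
    }
}
```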
```
src/
  main.rs           # Training loop
  config.rs         # TOML + CLI configuration
  network.rs        # ActorCritic neural network
  ppo.rs            # PPO algorithm (GAE, clipped surrogate)
  env.rs            # Environment trait + VecEnv
  envs/
    cartpole.rs     # CartPole test environment
    connect_four.rs # Connect Four with self-play
  checkpoint.rs     # Model save/load
  metrics.rs        # JSON-lines logger
configs/            # TOML configuration files
scripts/            # Python Aim watcher
runs/               # Training outputs (per-run dirs)
docs/               # Design documentation
```
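For orientation, a rough sketch of the actor-critic shape in Burn (the module layout, head names, and type paths here are assumptions about network.rs, which also applies orthogonal initialization):

```rust
use burn::{
    module::Module,
    nn::{Linear, Relu},
    tensor::{backend::Backend, Tensor},
};

// Shared torso with separate policy and value heads.
#[derive(Module, Debug)]
pub struct ActorCritic<B: Backend> {
    shared: Linear<B>,
    policy_head: Linear<B>, // one logit per discrete action
    value_head: Linear<B>,  // scalar state-value estimate
    activation: Relu,
}

impl<B: Backend> ActorCritic<B> {
    // Returns (action logits, value estimates) for a batch of observations.
    pub fn forward(&self, obs: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2>) {
        let h = self.activation.forward(self.shared.forward(obs));
        (self.policy_head.forward(h.clone()), self.value_head.forward(h))
    }
}
```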
Checkpoints are saved to `runs/<run_name>/checkpoints/`:

- `step_00010000/` - Checkpoint at step 10000
- `latest -> step_00020000/` - Symlink to the most recent checkpoint
- `best -> step_00015000/` - Symlink to the checkpoint with the highest average return
Each checkpoint includes model weights, optimizer state, and training metadata (step count, returns history, etc.).
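The symlink bookkeeping is plain std; a minimal Unix-only sketch of the repoint step (the real logic lives in checkpoint.rs):

```rust
use std::os::unix::fs::symlink;
use std::path::Path;

// Repoint `latest`/`best`: remove the old link (if any), then create a new
// one with a relative target such as "step_00020000".
fn repoint(dir: &Path, link_name: &str, target: &str) -> std::io::Result<()> {
    let link = dir.join(link_name);
    let _ = std::fs::remove_file(&link); // ignore "not found" on the first save
    symlink(target, link)
}
```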
Continue training from the last checkpoint in an existing run:
```sh
cargo run --release -- --resume runs/<run_name>
```

This loads the config from the run directory and continues where training left off. The global step, optimizer state, and metrics all continue from the checkpoint.
To train beyond the original total_timesteps:
```sh
cargo run --release -- --resume runs/<run_name> --total-timesteps 2000000
```

Note: Only `--total-timesteps` can be overridden when resuming. Other config changes are ignored to preserve training consistency.
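That restriction keeps a resumed run consistent with its earlier segment. Illustratively (the struct and function here are hypothetical, not the real config.rs):

```rust
// Hypothetical merge rule on --resume: the frozen config wins for everything
// except total_timesteps.
#[derive(Clone)]
struct Frozen {
    total_timesteps: u64,
    learning_rate: f64,
    // ... remaining fields stay frozen ...
}

fn merge_on_resume(frozen: Frozen, cli_total: Option<u64>) -> Frozen {
    Frozen {
        total_timesteps: cli_total.unwrap_or(frozen.total_timesteps),
        ..frozen // every other field keeps its frozen value
    }
}
```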
Create a new run starting from an existing checkpoint with different hyperparameters:
```sh
# Fork from best checkpoint with new learning rate
cargo run --release -- --fork runs/<run_name>/checkpoints/best \
    --learning-rate 0.0001 --total-timesteps 500000

# Fork from a specific step
cargo run --release -- --fork runs/<run_name>/checkpoints/step_00050000 \
    --learning-rate 0.0001
```

Forking:
- Creates a new run directory
- Preserves the global step from the checkpoint (graphs continue from that point)
- Allows any config changes (learning rate, hyperparameters, etc.)
- Starts fresh metrics but step numbers continue from the checkpoint
Each training run creates:
```
runs/<run_name>/
  config.toml    # Frozen config snapshot
  metrics.jsonl  # Streaming metrics
  checkpoints/   # Model checkpoints
```
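Freezing the snapshot is just a write at run start; a minimal sketch (the real setup code in main.rs may do more):

```rust
use std::path::Path;

// Create the run layout and freeze the effective config so the run can be
// resumed or audited later.
fn init_run(run_dir: &Path, effective_config_toml: &str) -> std::io::Result<()> {
    std::fs::create_dir_all(run_dir.join("checkpoints"))?;
    std::fs::write(run_dir.join("config.toml"), effective_config_toml)
}
```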
Implements all core details from the ICLR blog post "The 37 Implementation Details of Proximal Policy Optimization":
- Orthogonal weight initialization
- Adam epsilon = 1e-5
- Learning rate linear annealing
- GAE (lambda=0.95)
- Advantage normalization at minibatch level
- Clipped surrogate objective + value clipping
- Global gradient clipping (max norm 0.5)
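As a concrete example, the GAE recursion those updates rely on, written out for scalars (ppo.rs computes the same recursion over batched tensors; the `dones[t]` convention below marks episodes that terminate at step t):

```rust
// Scalar GAE sketch: A_t = delta_t + gamma * lambda * A_{t+1}, where
// delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), masked at episode boundaries.
fn gae(
    rewards: &[f32],
    values: &[f32],  // V(s_t) for each rollout step
    dones: &[bool],  // true if the episode terminated at step t
    last_value: f32, // bootstrap value for the state after the rollout
    gamma: f32,
    lambda: f32,
) -> Vec<f32> {
    let mut advantages = vec![0.0; rewards.len()];
    let mut acc = 0.0;
    for t in (0..rewards.len()).rev() {
        let not_done = if dones[t] { 0.0 } else { 1.0 };
        let next_value = if t + 1 < values.len() { values[t + 1] } else { last_value };
        let delta = rewards[t] + gamma * next_value * not_done - values[t];
        acc = delta + gamma * lambda * not_done * acc;
        advantages[t] = acc;
    }
    advantages
}
```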
See `docs/DESIGN.md` for architecture decisions and extension points:

- Add new environments by implementing the `Environment` trait (see the sketch after this list)
- Modify reward shaping in the rollout collection loop
- Add auxiliary heads to the network
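A hedged sketch of what a new environment might look like; the real `Environment` trait in env.rs will differ in names and signatures, so treat this as the general shape only:

```rust
// Illustrative stand-in for the Environment trait (not the real env.rs trait).
trait Environment {
    fn reset(&mut self) -> Vec<f32>;
    fn step(&mut self, action: usize) -> (Vec<f32>, f32, bool); // (obs, reward, done)
    fn num_actions(&self) -> usize;
}

// A toy two-action environment showing the wiring.
struct CoinFlip {
    steps: u32,
}

impl Environment for CoinFlip {
    fn reset(&mut self) -> Vec<f32> {
        self.steps = 0;
        vec![0.0]
    }
    fn step(&mut self, action: usize) -> (Vec<f32>, f32, bool) {
        self.steps += 1;
        let reward = if action == 0 { 1.0 } else { 0.0 };
        (vec![self.steps as f32], reward, self.steps >= 10)
    }
    fn num_actions(&self) -> usize {
        2
    }
}
```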
```sh
# Run tests
cargo test

# Check compilation
cargo check

# Build docs
cargo doc --open
```

License: MIT