Jax-Baseline

Jax-Baseline is a Reinforcement Learning implementation using JAX and Flax/Haiku libraries, mirroring the functionality of Stable-Baselines.

Features

  • 2-3 times faster than equivalent PyTorch and TensorFlow implementations
  • Optimized using JAX's Just-In-Time (JIT) compilation (see the sketch below)
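
The speedup comes from compiling the whole update step once with jax.jit and then reusing the compiled function on every training step. The snippet below is a minimal sketch of that pattern, built around a toy TD-error function invented for illustration; it is not code taken from Jax-Baseline.

import jax
import jax.numpy as jnp

# Toy 1-step TD error, jitted once and reused for every update.
# Purely illustrative; none of these names come from Jax-Baseline.
@jax.jit
def td_error(q, q_next, reward, gamma=0.99):
    return reward + gamma * jnp.max(q_next, axis=-1) - q

q = jnp.zeros((32,))         # Q(s, a) for a batch of 32 transitions
q_next = jnp.zeros((32, 4))  # Q(s', .) over 4 actions
reward = jnp.ones((32,))
print(td_error(q, q_next, reward))  # first call compiles; later calls reuse the XLA kernel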

Installation

pip install -r requirement.txt
pip install .
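
As an optional sanity check after installation, you can confirm which backend JAX picked up; this uses only the public JAX API and nothing from this repository.

import jax

print(jax.__version__)
print(jax.default_backend())  # "gpu" if a CUDA-enabled jaxlib is installed, otherwise "cpu"
print(jax.devices())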

Implementation Status

  • ✔️ : Implemented as an option
  • ✅ : Implemented as the default, as in the original paper
  • ❌ : Not implemented yet, or cannot be implemented

Supported Environments

| Name | Q-Net based | Actor-Critic based | DPG based |
| --- | --- | --- | --- |
| Gymnasium | ✔️ | ✔️ | ✔️ |
| VectorizedGym with Ray | ✔️ | ✔️ | ✔️ |
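
All trainers consume standard Gymnasium environments. The snippet below shows only the plain single-environment Gymnasium API for reference (CartPole is used simply because it ships with gymnasium); the Ray-based vectorized wrapper listed above is provided by this repository and its interface is not reproduced here.

import gymnasium as gym

# Plain single-environment Gymnasium loop; Jax-Baseline's Ray-based
# wrapper runs many such environments in parallel.
env = gym.make("CartPole-v1")
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()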

Implemented Algorithms

Q-Net based

| Name | Double [1] | Dueling [2] | PER [3] | N-step [4][5] | NoisyNet [6] | Munchausen [7] | Ape-X [8] | HL-Gauss [9] |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| DQN [10] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | |
| C51 [11] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| QRDQN [12] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | |
| IQN [13] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | | |
| FQF [14] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | | |
| SPR [15] | ✔️ | ✔️ | | | | | | |
| BBF [16] | ✔️ | ✔️ | ✔️ | | | | | |

Actor-Critic based

| Name | Box | Discrete | IMPALA [17] |
| --- | --- | --- | --- |
| A2C [18] | ✔️ | ✔️ | ✔️ |
| PPO [19] | ✔️ | ✔️ | ✔️ [20] |
| Truly PPO (TPPO) [21] | ✔️ | ✔️ | |
| SPO [22] | ✔️ | ✔️ | |

DPG based

| Name | PER [3] | N-step [4][5] | Ape-X [8] | Simba [23] | Simba-v2 [24] |
| --- | --- | --- | --- | --- | --- |
| DDPG [25] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| TD3 [26] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| SAC [27] | ✔️ | ✔️ | | ✔️ | ✔️ |
| DAC [28] | | | | | |
| TQC [29] | ✔️ | ✔️ | | ✔️ | ✔️ |
| TD7 [30] | ✅ (LAP [31]) | ✔️ | ✔️ | | |
| CrossQ [32] | ✔️ | ✔️ | | ✔️ | ✔️ |
| BRO [33] | | | | | |

Performance Comparison

Test

To test Atari with DQN (or C51, QRDQN, IQN, FQF):

python test/run_qnet.py --algo DQN --env BreakoutNoFrameskip-v4 --learning_rate 0.0002 \
		--steps 5e5 --batch 32 --train_freq 1 --target_update 1000 --node 512 \
		--hidden_n 1 --final_eps 0.01 --learning_starts 20000 --gamma 0.995 --clip_rewards

500K steps run in about 15 minutes on Atari Breakout (~540 steps/sec), measured on an NVIDIA RTX 3080 and an AMD Ryzen 9 5950X in a single process.

score : 9.600, epsilon : 0.010, loss : 0.181 |: 100%|███████| 500000/500000 [15:24<00:00, 540.88it/s]

Footnotes

  1. Double DQN paper

  2. Dueling DQN paper

  3. PER

  4. N-step TD

  5. RAINBOW DQN

  6. Noisy network

  7. Munchausen rl

  8. Ape-X

  9. HL-GAUSS

  10. DQN

  11. C51

  12. QRDQN

  13. IQN

  14. FQF

  15. SPR

  16. BBF

  17. IMPALA

  18. A3C

  19. PPO

  20. IMPALA + PPO, APPO

  21. Truly PPO

  22. SPO

  23. SIMBA

  24. SIMBAv2

  25. DDPG

  26. TD3

  27. SAC

  28. DAC

  29. TQC

  30. TD7

  31. LaP

  32. CrossQ

  33. BRO
