A unified JAX framework for memory-augmented reinforcement learning with RNNs, SSMs, Transformers and more

Memory-Augmented Reinforcement Learning in JAX 🧠

Most JAX RL libraries treat memory as an afterthought, bolting an LSTM onto an existing agent and calling it done. Memorax makes memory a first-class citizen. It provides a composable set of sequence model primitives (attention, SSMs, linear RNNs, and more) that snap together into full architectures like GTrXL or xLSTM, paired with algorithms and replay buffers designed from the ground up for recurrent training. Whether you're benchmarking a new memory architecture on POMDPs or scaling recurrent agents across environments, Memorax gives you the building blocks to do it entirely in JAX.

✨ Features

  • 🤖 Algorithms: DQN, PPO, SAC, PQN, MAPPO, R2D2, plus memory-augmented variants with burn-in
  • 🔁 Sequence Models: LSTM, GRU, xLSTM, FFM, SHM, S5, LRU, Mamba, MinGRU, Self-Attention, Linear Attention. Compose into GTrXL, GPT-2, and more. Support for RTRL
  • 🧬 Networks: ViT encoder. RoPE and ALiBi positional embeddings. MoE for horizontal scaling. RL² wrapper for meta-RL. GVF/Horde heads. C51 and HL-Gauss distributional value heads. Composable feature extractor → torso → head pipeline
  • 🎮 Environments: Gymnax, PopJym, PopGym Arcade, Navix, Craftax, Brax, MuJoCo, gxm, Grimax, XMiniGrid, JaxMARL
  • 📦 Buffers: Pure JAX episode replay with prioritized sampling via Flashbax
  • 📊 Logging: CLI dashboard, file, W&B, TensorboardX, Neptune
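The burn-in mentioned above is the R2D2 trick for recurrent replay: unroll the recurrent cell over a prefix of each stored sequence to warm up the hidden state (treating the result as a constant, with no gradient), then compute the training loss only on the remaining steps. A toy numpy sketch of the idea, with a made-up `rnn_cell`; this is illustrative, not Memorax's actual implementation:

```python
import numpy as np

def rnn_cell(h, x, W):
    # toy tanh recurrence, standing in for an LSTM/GRU cell
    return np.tanh(W @ h + x)

def burn_in_unroll(seq, h0, W, burn_in=4):
    """Warm up the hidden state on the first `burn_in` steps, train on the rest."""
    h = h0
    for x in seq[:burn_in]:      # burn-in: recompute state; no gradients flow here
        h = rnn_cell(h, x, W)
    warm_h = h                   # treated as a stop-gradient constant during training
    outputs = []
    for x in seq[burn_in:]:      # training segment: the loss is computed on these steps
        h = rnn_cell(h, x, W)
        outputs.append(h)
    return warm_h, np.stack(outputs)

rng = np.random.default_rng(0)
seq = rng.normal(size=(10, 3))   # a stored sequence of 10 observations
W = rng.normal(size=(3, 3)) * 0.1
warm_h, outs = burn_in_unroll(seq, np.zeros(3), W)
print(outs.shape)                # (6, 3): only post-burn-in steps contribute to the loss
```

This stale-state-plus-burn-in scheme avoids storing every hidden state in the buffer while still giving the training segment a reasonable starting state.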

📥 Installation

Install Memorax using pip:

pip install memorax

Optionally you can add support for CUDA with:

pip install memorax[cuda]

Optionally, set up Weights & Biases logging by signing in:

wandb login

🚀 Quick Start

Train a DQN agent on CartPole in under 30 lines:

import flax.linen as nn
import jax
import optax
from flashbax import make_item_buffer
from memorax.algorithms import DQN, DQNConfig
from memorax.environments import environment
from memorax.networks import FeatureExtractor, Network, heads

env, env_params = environment.make("gymnax::CartPole-v1")

cfg = DQNConfig(
    num_envs=10, buffer_size=10_000,
    tau=1.0, target_update_frequency=500, batch_size=64,
    start_e=1.0, end_e=0.05, exploration_fraction=0.5, train_frequency=10,
)

q_network = Network(
    feature_extractor=FeatureExtractor(observation_extractor=nn.Sequential((nn.Dense(120), nn.relu, nn.Dense(84), nn.relu))),
    head=heads.DiscreteQNetwork(action_dim=env.action_space(env_params).n),
)

optimizer = optax.adam(3e-4)
buffer = make_item_buffer(max_length=cfg.buffer_size, min_length=cfg.batch_size,
                          sample_batch_size=cfg.batch_size, add_sequences=True, add_batches=True)
epsilon = optax.linear_schedule(cfg.start_e, cfg.end_e, 250_000, 10_000)

agent = DQN(cfg, env, env_params, q_network, optimizer, buffer, epsilon)
key, state = agent.init(jax.random.key(0))
key, state, transitions = agent.train(key, state, num_steps=500_000)

See examples/ for complete scripts with logging and evaluation.
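The exploration schedule in the quick start uses optax's documented `linear_schedule(init_value, end_value, transition_steps, transition_begin)`: epsilon is held at `start_e` for the first 10,000 steps, then decays linearly to `end_e` over the next 250,000. A pure-Python replica of those semantics (written as a direct function of the step, whereas optax returns a callable):

```python
def linear_schedule(init_value, end_value, transition_steps, transition_begin=0, step=0):
    """Mirror of optax.linear_schedule: constant before transition_begin,
    linear interpolation over transition_steps, then constant at end_value."""
    frac = min(max((step - transition_begin) / transition_steps, 0.0), 1.0)
    return init_value + frac * (end_value - init_value)

print(linear_schedule(1.0, 0.05, 250_000, 10_000, step=0))                  # 1.0
print(round(linear_schedule(1.0, 0.05, 250_000, 10_000, step=135_000), 4))  # 0.525
print(round(linear_schedule(1.0, 0.05, 250_000, 10_000, step=300_000), 4))  # 0.05
```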

💡 Advanced Usage

Memorax's real power is in its composable network primitives. Here's a PPO agent with a GTrXL-style architecture, built by snapping together modular blocks:

import jax
import optax
from memorax.algorithms import PPO, PPOConfig
from memorax.environments import environment
from memorax.networks import (
    MLP, FFN, ALiBi, FeatureExtractor, GatedResidual, Network,
    PreNorm, SegmentRecurrence, SelfAttention, Stack, heads,
)

env, env_params = environment.make("gymnax::CartPole-v1")

cfg = PPOConfig(
    num_envs=8,
    num_steps=128,
    gae_lambda=0.95,
    num_minibatches=4,
    update_epochs=4,
    normalize_advantage=True,
    clip_coefficient=0.2,
    clip_value_loss=True,
    entropy_coefficient=0.01,
)

features, num_heads, num_layers = 64, 4, 2
feature_extractor = FeatureExtractor(observation_extractor=MLP(features=(features,)))
attention = GatedResidual(PreNorm(SegmentRecurrence(
    SelfAttention(features, num_heads, context_length=128, positional_embedding=ALiBi(num_heads)),
    memory_length=64, features=features,
)))
ffn = GatedResidual(PreNorm(FFN(features=features, expansion_factor=4)))
torso = Stack(blocks=(attention, ffn) * num_layers)

actor_network = Network(feature_extractor, torso, heads.Categorical(env.action_space(env_params).n))
critic_network = Network(feature_extractor, torso, heads.VNetwork())
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(3e-4))

agent = PPO(cfg, env, env_params, actor_network, critic_network, optimizer, optimizer)
key, state = agent.init(jax.random.key(0))
key, state, transitions = agent.train(key, state, num_steps=10_000)
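The `GatedResidual` blocks above follow the GTrXL idea of replacing the plain residual `x + f(x)` with a learned gate that blends the skip path and the sublayer output, typically biased so the block starts close to the identity. A toy numpy sketch of that gating (illustrative only; Memorax's actual gate may differ):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gated_residual(x, fx, Wg, bg):
    """Blend input x with sublayer output f(x) through a per-feature gate."""
    g = sigmoid(Wg @ x + bg)       # gate in (0, 1)
    return (1.0 - g) * x + g * fx  # g ~ 0 -> identity skip; g ~ 1 -> sublayer output

x = np.ones(4)
fx = np.full(4, 3.0)               # pretend sublayer output
out = gated_residual(x, fx, np.zeros((4, 4)), np.full(4, -10.0))
print(np.round(out, 3))            # ~[1. 1. 1. 1.]: a strong negative gate bias
                                   # keeps the identity path at initialization
```

Starting near the identity is what makes deep transformer torsos trainable with RL gradients in the GTrXL paper's setting.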

See examples/architectures for more architecture compositions including xLSTM and GPT-2 style networks.
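ALiBi, used for positional information in the attention block above, adds no position vectors at all: each head subtracts a linear penalty proportional to the query–key distance from its attention scores. A minimal numpy sketch of the bias matrix (slopes simplified here; the ALiBi paper derives them as a geometric sequence from the head count):

```python
import numpy as np

def alibi_bias(num_heads, seq_len):
    """Per-head linear distance penalties to add to attention scores."""
    slopes = 2.0 ** -np.arange(1, num_heads + 1)            # simplified head-specific slopes
    dist = np.arange(seq_len)[None, :] - np.arange(seq_len)[:, None]
    dist = np.minimum(dist, 0)                              # future positions get 0 here;
                                                            # the causal mask handles them
    return slopes[:, None, None] * dist[None]               # shape (heads, query, key)

bias = alibi_bias(num_heads=2, seq_len=4)
print(bias[0])  # head 0: penalties grow linearly with distance into the past
```

Because the penalty is a fixed function of distance, ALiBi extrapolates to sequences longer than those seen in training, which pairs well with the segment-level recurrence used in the GTrXL-style torso.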

📂 Project Structure

```
memorax/
├─ examples/          # Small runnable scripts (e.g., DQN CartPole)
└─ memorax/
   ├─ algorithms/     # DQN, PPO, SAC, PQN, ...
   ├─ networks/       # MLP, CNN, ViT, RNN, heads, ...
   ├─ environments/   # Gymnax / PopGym / Brax / ...
   ├─ buffers/        # Custom flashbax buffers
   ├─ loggers/        # CLI, WandB, TensorBoardX integrations
   └─ utils/
```

🧩 JAX POMDP Ecosystem

Memorax is designed to work alongside a growing suite of JAX-native tools focused on partial observability and memory. These projects provide the foundational architectures and benchmarks for modern memory-augmented RL:

🧠 Architectures & Infrastructure

  • Memax: A library for efficient sequence and recurrent modeling in JAX. It provides unified interfaces for fast recurrent state resets and associative scans, serving as a powerful primitive for building memory architectures.
  • Flashbax: The library powering Memorax's buffer system. It provides high-performance, JAX-native experience replay buffers optimized for sequence storage and prioritized sampling.
  • Gymnax: The standard for JAX-native RL environments. Memorax provides seamless wrappers to run recurrent agents on these vectorized tasks.

🎮 POMDP Benchmarks & Environments

  • PopGym Arcade: A JAX-native suite of "pixel-perfect" POMDP environments. It features Atari-style games specifically designed to test long-term memory with hardware-accelerated rendering.
  • PopJym: A fast, JAX-native implementation of the POPGym benchmark suite, providing a variety of classic POMDP tasks optimized for massive vectorization.
  • Navix: Accelerated MiniGrid-style environments. These are excellent for testing spatial reasoning and navigation in partially observable grid worlds.
  • XLand-MiniGrid: A high-throughput meta-RL environment suite that provides massive task diversity for testing agent generalization in POMDPs.

📄 License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.

📚 Citation

If you use Memorax for your work, please cite:

@software{memorax2025github,
  title   = {Memorax: A Unified Framework for Memory-Augmented Reinforcement Learning},
  author  = {Noah Farr},
  year    = {2025},
  url     = {https://github.com/memory-rl/memorax}
}

๐Ÿ™ Acknowledgments

Special thanks to @huterguier for the valuable discussions and advice on the API design.
