Most JAX RL libraries treat memory as an afterthought, bolting an LSTM onto an existing agent and calling it done. Memorax makes memory a first-class citizen. It provides a composable set of sequence model primitives (attention, SSMs, linear RNNs, and more) that snap together into full architectures like GTrXL or xLSTM, paired with algorithms and replay buffers designed from the ground up for recurrent training. Whether you're benchmarking a new memory architecture on POMDPs or scaling recurrent agents across environments, Memorax gives you the building blocks to do it entirely in JAX.
| | Details |
|---|---|
| 🤖 Algorithms | DQN, PPO, SAC, PQN, MAPPO, R2D2 + memory-augmented variants with burn-in |
| 🧠 Sequence Models | LSTM, GRU, xLSTM, FFM, SHM, S5, LRU, Mamba, MinGRU, Self-Attention, Linear Attention. Compose into GTrXL, GPT-2, and more. Support for RTRL |
| 🧬 Networks | ViT encoder. RoPE and ALiBi positional embeddings. MoE for horizontal scaling. RL² wrapper for meta-RL. GVF/Horde heads. C51 and HL-Gauss distributional value heads. Composable feature extractor → torso → head pipeline |
| 🎮 Environments | Gymnax, PopJym, PopGym Arcade, Navix, Craftax, Brax, MuJoCo, gxm, Grimax, XMiniGrid, JaxMARL |
| 📦 Buffers | Pure JAX episode replay with prioritized sampling via Flashbax |
| 📊 Logging | CLI Dashboard, File, W&B, TensorboardX, Neptune |
Install Memorax using pip:

```bash
pip install memorax
```

Optionally, you can add CUDA support with:

```bash
pip install memorax[cuda]
```

Optional: Set up Weights & Biases for logging by logging in:

```bash
wandb login
```

Train a DQN agent on CartPole in under 30 lines:
```python
import flax.linen as nn
import jax
import optax
from flashbax import make_item_buffer

from memorax.algorithms import DQN, DQNConfig
from memorax.environments import environment
from memorax.networks import FeatureExtractor, Network, heads

env, env_params = environment.make("gymnax::CartPole-v1")
cfg = DQNConfig(
    num_envs=10, buffer_size=10_000,
    tau=1.0, target_update_frequency=500, batch_size=64,
    start_e=1.0, end_e=0.05, exploration_fraction=0.5, train_frequency=10,
)
q_network = Network(
    feature_extractor=FeatureExtractor(observation_extractor=nn.Sequential((nn.Dense(120), nn.relu, nn.Dense(84), nn.relu))),
    head=heads.DiscreteQNetwork(action_dim=env.action_space(env_params).n),
)
optimizer = optax.adam(3e-4)
buffer = make_item_buffer(max_length=cfg.buffer_size, min_length=cfg.batch_size,
                          sample_batch_size=cfg.batch_size, add_sequences=True, add_batches=True)
epsilon = optax.linear_schedule(cfg.start_e, cfg.end_e, 250_000, 10_000)
agent = DQN(cfg, env, env_params, q_network, optimizer, buffer, epsilon)
key, state = agent.init(jax.random.key(0))
key, state, transitions = agent.train(key, state, num_steps=500_000)
```

See examples/ for complete scripts with logging and evaluation.
Memorax's real power is in its composable network primitives. Here's a PPO agent with a GTrXL-style architecture, built by snapping together modular blocks:
```python
import jax
import optax

from memorax.algorithms import PPO, PPOConfig
from memorax.environments import environment
from memorax.networks import (
    MLP, FFN, ALiBi, FeatureExtractor, GatedResidual, Network,
    PreNorm, SegmentRecurrence, SelfAttention, Stack, heads,
)

env, env_params = environment.make("gymnax::CartPole-v1")
cfg = PPOConfig(
    num_envs=8,
    num_steps=128,
    gae_lambda=0.95,
    num_minibatches=4,
    update_epochs=4,
    normalize_advantage=True,
    clip_coefficient=0.2,
    clip_value_loss=True,
    entropy_coefficient=0.01,
)
features, num_heads, num_layers = 64, 4, 2
feature_extractor = FeatureExtractor(observation_extractor=MLP(features=(features,)))
attention = GatedResidual(PreNorm(SegmentRecurrence(
    SelfAttention(features, num_heads, context_length=128, positional_embedding=ALiBi(num_heads)),
    memory_length=64, features=features,
)))
ffn = GatedResidual(PreNorm(FFN(features=features, expansion_factor=4)))
torso = Stack(blocks=(attention, ffn) * num_layers)
actor_network = Network(feature_extractor, torso, heads.Categorical(env.action_space(env_params).n))
critic_network = Network(feature_extractor, torso, heads.VNetwork())
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(3e-4))
agent = PPO(cfg, env, env_params, actor_network, critic_network, optimizer, optimizer)
key, state = agent.init(jax.random.key(0))
key, state, transitions = agent.train(key, state, num_steps=10_000)
```

See examples/architectures for more architecture compositions, including xLSTM and GPT-2 style networks.
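The GTrXL-style example above passes `ALiBi(num_heads)` as the positional embedding. ALiBi replaces learned position vectors with a per-head linear penalty added to the attention logits, so keys further in the past are scored lower. Here is a minimal pure-JAX sketch of that bias, an illustration of the idea rather than Memorax's `ALiBi` module:

```python
import jax.numpy as jnp

# Illustration only (not the Memorax ALiBi module): each attention head h
# gets a geometric slope m_h, and the logits for query position i and key
# position j receive an additive bias m_h * (j - i), which is zero at the
# query position and increasingly negative for keys further in the past.

def alibi_slopes(num_heads):
    # slopes 2^(-8/H), 2^(-16/H), ..., 2^(-8), as in the ALiBi paper
    return 2.0 ** (-8.0 * jnp.arange(1, num_heads + 1) / num_heads)

def alibi_bias(num_heads, seq_len):
    pos = jnp.arange(seq_len)
    dist = pos[None, :] - pos[:, None]  # (T, T): j - i, negative for past keys
    return alibi_slopes(num_heads)[:, None, None] * dist[None]  # (H, T, T)

# Added to attention logits before the causal mask and softmax.
bias = alibi_bias(num_heads=4, seq_len=128)
```

Because the bias depends only on relative distance, it extrapolates to sequence lengths longer than those seen in training, which is one reason it pairs well with segment-level recurrence.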
Memorax is designed to work alongside a growing suite of JAX-native tools focused on partial observability and memory. These projects provide the foundational architectures and benchmarks for modern memory-augmented RL:
- Memax: A library for efficient sequence and recurrent modeling in JAX. It provides unified interfaces for fast recurrent state resets and associative scans, serving as a powerful primitive for building memory architectures.
- Flashbax: The library powering Memorax's buffer system. It provides high-performance, JAX-native experience replay buffers optimized for sequence storage and prioritized sampling.
- Gymnax: The standard for JAX-native RL environments. Memorax provides seamless wrappers to run recurrent agents on these vectorized tasks.
- PopGym Arcade: A JAX-native suite of "pixel-perfect" POMDP environments. It features Atari-style games specifically designed to test long-term memory with hardware-accelerated rendering.
- PopJym: A fast, JAX-native implementation of the POPGym benchmark suite, providing a variety of classic POMDP tasks optimized for massive vectorization.
- Navix: Accelerated MiniGrid-style environments. These are excellent for testing spatial reasoning and navigation in partially observable grid worlds.
- XLand-MiniGrid: A high-throughput meta-RL environment suite that provides massive task diversity for testing agent generalization in POMDPs.
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
If you use Memorax for your work, please cite:
```bibtex
@software{memorax2025github,
  title = {Memorax: A Unified Framework for Memory-Augmented Reinforcement Learning},
  author = {Noah Farr},
  year = {2025},
  url = {https://github.com/memory-rl/memorax}
}
```
Special thanks to @huterguier for the valuable discussions and advice on the API design.
