# Grimax

A JAX-native multi-agent grid world reinforcement learning library.

## Features

- **JAX-Native** — All operations are pure JAX, fully JIT-compilable for maximum performance
- **Multi-Agent First** — Designed from the ground up for multi-agent reinforcement learning (MARL)
- **Functional API** — Explicit state management with no hidden side effects
- **Vectorizable** — Full `vmap` support for parallel environment rollouts
- **Fair Collision Resolution** — Random tie-breaking for agent collisions (not index priority)
- **Beautiful Rendering** — Sprite-based renderer using the 0x72 DungeonTileset II
## Installation

```bash
# Basic installation
pip install grimax

# With GIF rendering support
pip install "grimax[gif]"

# Development installation
pip install "grimax[dev]"
```

To install from source:

```bash
git clone https://github.com/noahfarr/grimax.git
cd grimax
pip install -e .
```

## Quick Start

```python
import jax
import grimax

# Create a multi-agent environment
env = grimax.make("CooperativeButtonPress-v1")

# Initialize
key = jax.random.PRNGKey(42)
state, timestep = env.init(key)

# Run a step with actions for all agents
key, step_key = jax.random.split(key)
actions = jax.numpy.array([0, 1, 2])  # One action per agent
state, timestep = env.step(step_key, state, actions)

# Access results
observation = timestep.observation  # (num_agents, view_size, view_size, 3)
reward = timestep.reward            # (num_agents,)
done = timestep.done                # boolean
```

## Performance

Grimax is designed for high-performance MARL research. Both `init` and `step` are fully JIT-compilable:
```python
import jax
import grimax

env = grimax.make("CooperativeButtonPress-v1")

# JIT compile for faster execution
@jax.jit
def run_episode(key):
    state, timestep = env.init(key)

    def step_fn(carry, _):
        state, key = carry
        key, step_key, action_key = jax.random.split(key, 3)
        # Random actions for all agents
        actions = jax.random.randint(action_key, (3,), 0, 4)
        state, timestep = env.step(step_key, state, actions)
        return (state, key), timestep.reward.sum()

    (final_state, _), rewards = jax.lax.scan(step_fn, (state, key), None, length=100)
    return rewards.sum()

# Vectorize across multiple environments
@jax.jit
@jax.vmap
def batch_init(keys):
    return env.init(keys)

keys = jax.random.split(jax.random.PRNGKey(0), 1024)
states, timesteps = batch_init(keys)  # 1024 parallel environments!
```

## Environments

### CooperativeButtonPress

A multi-agent collaborative environment where agents must press their respective colored buttons to complete the task.
```python
# Default: 3 agents
env = grimax.make("CooperativeButtonPress-v1")

# Custom agent count
env = grimax.make("CooperativeButtonPress-v1", num_agents=2)
env = grimax.make("CooperativeButtonPress-v1", num_agents=4)
```

Environment details:

- Grid size: 10x10
- Observation: 7x7 egocentric view per agent
- Horizon: 200 steps
- Goal: All agents must stand on their color-matched buttons simultaneously
- Collision handling: Random tie-breaking (fair for all agents)
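The simultaneous-press goal above can be sketched as a pure JAX predicate. This is an illustration only, not grimax's internal termination check; `agent_pos` and `button_pos` are hypothetical `(num_agents, 2)` coordinate arrays introduced for this example.

```python
import jax.numpy as jnp

def all_buttons_pressed(agent_pos, button_pos):
    """True iff every agent stands on its color-matched button.

    agent_pos, button_pos: (num_agents, 2) integer grid coordinates,
    where row i of button_pos is agent i's matching button.
    """
    on_button = jnp.all(agent_pos == button_pos, axis=-1)  # (num_agents,) bools
    return jnp.all(on_button)                              # scalar: all pressed at once

# Agents 0 and 1 stand on their buttons; agent 2 does not.
agent_pos = jnp.array([[1, 1], [2, 3], [5, 5]])
button_pos = jnp.array([[1, 1], [2, 3], [4, 5]])
print(bool(all_buttons_pressed(agent_pos, button_pos)))  # False
```

Because the predicate is pure `jax.numpy`, it stays JIT- and `vmap`-friendly, in keeping with the library's design.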
## Configuration

Override default parameters when creating environments:

```python
env = grimax.make(
    "CooperativeButtonPress-v1",
    grid_size=12,
    num_agents=5,
    horizon=100,
    view_size=5,
    goal_reward=10.0,
    step_penalty=-0.01,
    collision_penalty=-0.1,
)
```

## Rendering

Grimax includes a JAX-native sprite renderer:
```python
from grimax import Renderer

# Create renderer from tileset
renderer = Renderer.from_tileset()

# Render state to RGB image
rgb_image = renderer.render(state)  # (height, width, 3) uint8 array
```

Render full trajectories as GIFs (requires `grimax[gif]`):
```python
from grimax.rendering import render_trajectory

render_trajectory(
    renderer,
    states,
    output_path="trajectory.gif",
    fps=10,
)
```

## Wrappers

### AutoResetWrapper

Automatically reset the environment when an episode ends:
```python
from grimax.wrappers import AutoResetWrapper

env = grimax.make("CooperativeButtonPress-v1")
env = AutoResetWrapper(env)
```

### MultiAgentGymnaxWrapper

Make Grimax environments compatible with Gymnax/Memorax:
```python
from grimax.wrappers import MultiAgentGymnaxWrapper

env = grimax.make("CooperativeButtonPress-v1")
env = MultiAgentGymnaxWrapper(env)
```

### FlattenMultiAgentObservationWrapper

Flatten observations for MLP-based policies:
```python
from grimax.wrappers import FlattenMultiAgentObservationWrapper

env = grimax.make("CooperativeButtonPress-v1")
env = AutoResetWrapper(env)
env = MultiAgentGymnaxWrapper(env)
env = FlattenMultiAgentObservationWrapper(env)
```

## Action Space

Grimax supports 4 discrete actions:
| Action | ID | Description |
|---|---|---|
| `MOVE_UP` | 0 | Move one cell up |
| `MOVE_RIGHT` | 1 | Move one cell right |
| `MOVE_DOWN` | 2 | Move one cell down |
| `MOVE_LEFT` | 3 | Move one cell left |
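Bound to names, the IDs above translate naturally into vectorized position updates. The constant names and the (row, col) displacement table below are this sketch's assumptions — grimax may expose its own constants and use a different coordinate convention:

```python
import jax.numpy as jnp

# Action IDs from the table above; the names are this example's own.
MOVE_UP, MOVE_RIGHT, MOVE_DOWN, MOVE_LEFT = 0, 1, 2, 3

# One plausible (row, col) displacement per action, assuming row 0 is the top row.
DELTAS = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]])

positions = jnp.array([[5, 5], [2, 2], [0, 7]])        # (num_agents, 2)
actions = jnp.array([MOVE_UP, MOVE_RIGHT, MOVE_DOWN])  # one ID per agent
new_positions = positions + DELTAS[actions]            # vectorized, no Python loop
print(new_positions.tolist())  # [[4, 5], [2, 3], [1, 7]]
```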
## Observation Space

Observations are egocentric views of the grid, encoded as a 3-channel array:

- Channel 0: Entity type (wall, floor, agent, button, etc.)
- Channel 1: Entity color (8 colors available)
- Channel 2: Entity state (e.g., button pressed/unpressed)

Shape: `(num_agents, view_size, view_size, 3)`
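Given that layout, the per-channel views can be separated by indexing the last axis. The snippet below uses a randomly filled dummy array of the documented shape rather than a real observation:

```python
import jax

num_agents, view_size = 3, 7
key = jax.random.PRNGKey(0)

# Dummy observation with the documented shape (values 0-7, e.g. the 8 colors).
obs = jax.random.randint(key, (num_agents, view_size, view_size, 3), 0, 8)

entity_type = obs[..., 0]   # channel 0: what occupies each cell
entity_color = obs[..., 1]  # channel 1: one of 8 colors
entity_state = obs[..., 2]  # channel 2: e.g. button pressed/unpressed

print(entity_type.shape)  # (3, 7, 7)
```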
## Project Structure

```
grimax/
├── core/        # Core abstractions (Environment, State, Timestep)
├── envs/        # Multi-agent environments (CooperativeButtonPress)
├── grid/        # Grid utilities (encoding, dynamics, observations)
├── rendering/   # Sprite-based renderer
└── wrappers/    # Environment wrappers (AutoReset, Gymnax, etc.)
```
## Examples

Check out the `examples/` directory for complete usage examples:

- `ippo_cooperative_button_press.py` — Full IPPO training pipeline with Memorax
- `play.py` — Interactive play with keyboard controls
- `render_sprite.py` — Sprite rendering demo
```bash
# Train IPPO agents on CooperativeButtonPress
python examples/ippo_cooperative_button_press.py

# Interactive play
python examples/play.py --env_id CooperativeButtonPress-v1
```

## Testing

```bash
# Run tests
pytest tests/

# Run with coverage
pytest tests/ --cov=grimax
```

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.
1. Fork the repository
2. Create your feature branch (`git checkout -b feature/amazing-feature`)
3. Commit your changes (`git commit -m 'Add amazing feature'`)
4. Push to the branch (`git push origin feature/amazing-feature`)
5. Open a Pull Request
## License

This project is licensed under the MIT License — see the LICENSE file for details.
## Acknowledgments

- 0x72 for the beautiful DungeonTileset II sprites
- The JAX team for an amazing framework
- The Gymnax project for inspiration
Made for the MARL research community