Grimax

A JAX-native, multi-agent grid-world reinforcement learning library

Python 3.12+ JAX License: MIT


Features

  • JAX-Native — All operations are pure JAX, fully JIT-compilable for maximum performance
  • Multi-Agent First — Designed from the ground up for multi-agent reinforcement learning (MARL)
  • Functional API — Explicit state management with no hidden side effects
  • Vectorizable — Full vmap support for parallel environment rollouts
  • Fair Collision Resolution — Random tie-breaking for agent collisions (not index-priority)
  • Beautiful Rendering — Sprite-based renderer using the 0x72 DungeonTileset II

Installation

# Basic installation
pip install grimax

# With GIF rendering support
pip install "grimax[gif]"

# Development installation
pip install "grimax[dev]"

From Source

git clone https://github.com/noahfarr/grimax.git
cd grimax
pip install -e .

Quick Start

import jax
import grimax

# Create a multi-agent environment
env = grimax.make("CooperativeButtonPress-v1")

# Initialize
key = jax.random.PRNGKey(42)
state, timestep = env.init(key)

# Run a step with actions for all agents
key, step_key = jax.random.split(key)
actions = jax.numpy.array([0, 1, 2])  # One action per agent
state, timestep = env.step(step_key, state, actions)

# Access results
observation = timestep.observation  # (num_agents, view_size, view_size, 3)
reward = timestep.reward            # (num_agents,)
done = timestep.done                # boolean

JIT Compilation & Vectorization

Grimax is designed for high-performance MARL research. Both init and step are fully JIT-compilable:

import jax
import grimax

env = grimax.make("CooperativeButtonPress-v1")

# JIT compile for faster execution
@jax.jit
def run_episode(key):
    state, timestep = env.init(key)

    def step_fn(carry, _):
        state, key = carry
        key, step_key, action_key = jax.random.split(key, 3)
        # Random actions for all agents
        actions = jax.random.randint(action_key, (3,), 0, 4)
        state, timestep = env.step(step_key, state, actions)
        return (state, key), timestep.reward.sum()

    (final_state, _), rewards = jax.lax.scan(step_fn, (state, key), None, length=100)
    return rewards.sum()

# Vectorize across multiple environments
@jax.jit
@jax.vmap
def batch_init(key):  # vmap maps over the leading axis, so each call sees one key
    return env.init(key)

keys = jax.random.split(jax.random.PRNGKey(0), 1024)
states, timesteps = batch_init(keys)  # 1024 parallel environments!
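The same pattern extends to rollouts: vmap an entire scan-based episode across a batch of keys. The sketch below uses a stand-in `toy_step` (not grimax's dynamics) purely to show the shape of the pattern; with grimax you would substitute `env.init` and `env.step`.

```python
import jax
import jax.numpy as jnp

# Stand-in for env.step: moves 3 toy agents by their integer actions and
# returns a dummy reward. Illustrative only, not the grimax dynamics.
def toy_step(state, actions):
    new_state = state + actions          # (3,) integer "positions"
    reward = jnp.sum(new_state == 0)     # dummy reward signal
    return new_state, reward

def rollout(key, length=100):
    state = jnp.zeros(3, dtype=jnp.int32)

    def step_fn(carry, _):
        state, key = carry
        key, action_key = jax.random.split(key)
        actions = jax.random.randint(action_key, (3,), 0, 4)
        state, reward = toy_step(state, actions)
        return (state, key), reward

    (_, _), rewards = jax.lax.scan(step_fn, (state, key), None, length=length)
    return rewards.sum()

# vmap the whole rollout across a batch of keys, then jit the batched function
batched_rollout = jax.jit(jax.vmap(rollout))
keys = jax.random.split(jax.random.PRNGKey(0), 8)
returns = batched_rollout(keys)  # one return per environment, shape (8,)
```

Because the scan and the per-step randomness are all pure JAX, the batched rollout compiles to a single fused program with no Python loop overhead.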

Environments

Cooperative Button Press

A multi-agent collaborative environment where agents must press their respective colored buttons to complete the task.

# Default: 3 agents
env = grimax.make("CooperativeButtonPress-v1")

# Custom agent count
env = grimax.make("CooperativeButtonPress-v1", num_agents=2)
env = grimax.make("CooperativeButtonPress-v1", num_agents=4)

Environment Details:

  • Grid size: 10x10
  • Observation: 7x7 egocentric view per agent
  • Horizon: 200 steps
  • Goal: All agents must stand on their color-matched buttons simultaneously
  • Collision handling: Random tie-breaking (fair for all agents)
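The 7x7 egocentric view per agent implies padding at the grid border so observations keep a fixed shape. The following is an illustrative sketch of how such a crop can be taken from a 10x10 grid (grimax's own implementation may differ; the padding value `-1` standing in for walls is an assumption):

```python
import jax
import jax.numpy as jnp

grid_size, view_size = 10, 7
radius = view_size // 2  # 3 cells on each side of the agent

# A toy grid whose cell values are just their flat indices
grid = jnp.arange(grid_size * grid_size).reshape(grid_size, grid_size)

# Pad so views near the border keep a fixed shape; -1 stands in for walls
padded = jnp.pad(grid, radius, constant_values=-1)

agent_pos = jnp.array([0, 9])  # corner agent: view must include padding
r, c = agent_pos
# Agent at grid (r, c) sits at padded (r + radius, c + radius), so the
# window starting at padded (r, c) is exactly centered on the agent
view = jax.lax.dynamic_slice(padded, (r, c), (view_size, view_size))
```

`dynamic_slice` keeps the crop JIT-compatible since the output shape is static even though the start indices are traced.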

Custom Parameters

Override default parameters when creating environments:

env = grimax.make(
    "CooperativeButtonPress-v1",
    grid_size=12,
    num_agents=5,
    horizon=100,
    view_size=5,
    goal_reward=10.0,
    step_penalty=-0.01,
    collision_penalty=-0.1,
)

Rendering

Grimax includes a JAX-native sprite renderer:

from grimax import Renderer

# Create renderer from tileset
renderer = Renderer.from_tileset()

# Render state to RGB image
rgb_image = renderer.render(state)  # (height, width, 3) uint8 array

Trajectory Rendering

Render full trajectories as GIFs (requires grimax[gif]):

from grimax.rendering import render_trajectory

render_trajectory(
    renderer,
    states,
    output_path="trajectory.gif",
    fps=10,
)

Wrappers

Auto Reset

Automatically reset the environment when an episode ends:

from grimax.wrappers import AutoResetWrapper

env = grimax.make("CooperativeButtonPress-v1")
env = AutoResetWrapper(env)
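Conceptually, an auto-reset wrapper must swap in a freshly initialized state when `done` is true, in a way that stays JIT-compatible. The toy sketch below shows one common way to do this with `jnp.where`; it is an illustration of the idea, not grimax's actual wrapper code:

```python
import jax
import jax.numpy as jnp

# Toy environment: state counts up, episode ends when it reaches 5
def toy_init(key):
    return jnp.zeros(3)

def toy_step(state):
    state = state + 1.0
    done = state[0] >= 5.0
    return state, done

@jax.jit
def step_with_autoreset(key, state):
    stepped, done = toy_step(state)
    fresh = toy_init(key)
    # Branchless select: reset state when done, else keep the stepped state
    next_state = jnp.where(done, fresh, stepped)
    return next_state, done
```

Computing both branches and selecting with `jnp.where` avoids Python control flow on traced values, which is why auto-reset composes cleanly with `jit`, `vmap`, and `scan`.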

Gymnax Compatibility

Make Grimax environments compatible with Gymnax/Memorax:

from grimax.wrappers import MultiAgentGymnaxWrapper

env = grimax.make("CooperativeButtonPress-v1")
env = MultiAgentGymnaxWrapper(env)

Flatten Observations

Flatten observations for MLP-based policies:

from grimax.wrappers import FlattenMultiAgentObservationWrapper

env = grimax.make("CooperativeButtonPress-v1")
env = AutoResetWrapper(env)
env = MultiAgentGymnaxWrapper(env)
env = FlattenMultiAgentObservationWrapper(env)

Actions

Grimax supports 4 discrete actions:

Action      ID   Description
MOVE_UP     0    Move one cell up
MOVE_RIGHT  1    Move one cell right
MOVE_DOWN   2    Move one cell down
MOVE_LEFT   3    Move one cell left
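One natural way to consume these IDs is as a lookup table of (row, col) displacements, which vectorizes over all agents at once. The constants below mirror the table above but are an illustrative sketch; the actual definitions live inside grimax:

```python
import jax.numpy as jnp

# Action IDs per the table above (illustrative constants, not grimax's own)
MOVE_UP, MOVE_RIGHT, MOVE_DOWN, MOVE_LEFT = 0, 1, 2, 3

# (row, col) displacement for each action ID
DISPLACEMENTS = jnp.array([
    [-1,  0],  # MOVE_UP
    [ 0,  1],  # MOVE_RIGHT
    [ 1,  0],  # MOVE_DOWN
    [ 0, -1],  # MOVE_LEFT
])

# Apply one action per agent with a single vectorized gather-and-add
positions = jnp.array([[5, 5], [2, 3], [7, 1]])  # (num_agents, 2)
actions = jnp.array([MOVE_UP, MOVE_RIGHT, MOVE_LEFT])
new_positions = positions + DISPLACEMENTS[actions]
```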

Observations

Observations are egocentric views of the grid, encoded as a 3-channel array:

  • Channel 0: Entity type (wall, floor, agent, button, etc.)
  • Channel 1: Entity color (8 colors available)
  • Channel 2: Entity state (e.g., button pressed/unpressed)

Shape: (num_agents, view_size, view_size, 3)
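Given that layout, splitting an observation batch into its three semantic channels is a simple slice on the last axis. A minimal sketch, using a zero-filled dummy observation with the default shapes from above:

```python
import jax.numpy as jnp

# Dummy observation batch with the README's default shapes
num_agents, view_size = 3, 7
obs = jnp.zeros((num_agents, view_size, view_size, 3), dtype=jnp.int32)

entity_type  = obs[..., 0]  # wall, floor, agent, button, ...
entity_color = obs[..., 1]  # one of 8 colors
entity_state = obs[..., 2]  # e.g. button pressed / unpressed

# In an egocentric view, the observing agent sits at the center cell
center = view_size // 2
```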


Architecture

grimax/
├── core/           # Core abstractions (Environment, State, Timestep)
├── envs/           # Multi-agent environments (CooperativeButtonPress)
├── grid/           # Grid utilities (encoding, dynamics, observations)
├── rendering/      # Sprite-based renderer
└── wrappers/       # Environment wrappers (AutoReset, Gymnax, etc.)

Examples

Check out the examples/ directory for complete usage examples:

  • ippo_cooperative_button_press.py — Full IPPO training pipeline with Memorax
  • play.py — Interactive play with keyboard controls
  • render_sprite.py — Sprite rendering demo

# Train IPPO agents on CooperativeButtonPress
python examples/ippo_cooperative_button_press.py

# Interactive play
python examples/play.py --env_id CooperativeButtonPress-v1

Testing

# Run tests
pytest tests/

# Run with coverage
pytest tests/ --cov=grimax

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

  1. Fork the repository
  2. Create your feature branch (git checkout -b feature/amazing-feature)
  3. Commit your changes (git commit -m 'Add amazing feature')
  4. Push to the branch (git push origin feature/amazing-feature)
  5. Open a Pull Request

License

This project is licensed under the MIT License - see the LICENSE file for details.


Acknowledgments

  • 0x72 for the beautiful DungeonTileset II sprites
  • The JAX team for an amazing framework
  • The Gymnax project for inspiration

Made for the MARL research community
