
keraJLi/envelope

💌 Envelope: a JAX-native environment interface

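```python
import jax
import matplotlib.pyplot as plt

import envelope

# Create environments from JAX-native suites you have installed, ...
env = envelope.create("gymnax::CartPole-v1")

# ... interact with the environments using a simple interface, ...
init_key, action_key = jax.random.split(jax.random.PRNGKey(0))
actions = jax.random.randint(action_key, (100,), 0, 2)  # assumed: discrete CartPole actions 0/1
state, info = env.init(init_key)
state, infos = jax.lax.scan(env.step, state, actions)  # final state, stacked per-step infos
plt.plot(infos.reward.cumsum())

# ... and enjoy a powerful ecosystem of wrappers.
env = envelope.wrappers.AutoResetWrapper(env)
env = envelope.wrappers.VmapWrapper(env)
env = envelope.wrappers.ObservationNormalizationWrapper(env)
```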
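The quickstart needs gymnax installed. The self-contained sketch below mimics the init/step convention with a hypothetical ToyEnv built on a plain NamedTuple (not envelope's own environment class, and the State/Info field names are illustrative). It shows why jax.lax.scan(env.step, ...) works: step(state, action) -> (State, Info) has exactly the (carry, x) -> (carry, y) shape scan expects, so the final state comes back as the carry and the per-step infos come back stacked. Because the environment is a pytree, it can also be passed straight into a jitted function.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    pos: jax.Array  # 1-D position of the agent
    t: jax.Array    # step counter, used for truncation


class Info(NamedTuple):
    # Hypothetical field names; envelope's actual Info may differ.
    obs: jax.Array
    reward: jax.Array
    terminated: jax.Array
    truncated: jax.Array


class ToyEnv(NamedTuple):
    # A NamedTuple is a pytree, so these parameters can be traced by jit/vmap.
    goal: float = 5.0
    max_steps: int = 50

    def init(self, key):
        # Random start in [-1, 1); the key is the only source of randomness.
        pos = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
        state = State(pos=pos, t=jnp.zeros((), jnp.int32))
        info = Info(obs=pos, reward=jnp.zeros(()),
                    terminated=jnp.array(False), truncated=jnp.array(False))
        return state, info

    def step(self, state, action):
        # Action 1 moves right, anything else moves left. No auto-reset:
        # the toy env keeps stepping after the goal is reached.
        pos = state.pos + jnp.where(action == 1, 1.0, -1.0)
        t = state.t + 1
        terminated = pos >= self.goal
        truncated = t >= self.max_steps
        reward = jnp.where(terminated, 1.0, 0.0)
        return State(pos=pos, t=t), Info(pos, reward, terminated, truncated)


key = jax.random.PRNGKey(0)
env = ToyEnv()
state, info = env.init(key)
actions = jnp.ones(10, dtype=jnp.int32)  # always move right

# step(state, action) -> (State, Info) matches scan's (carry, x) -> (carry, y).
final_state, infos = jax.lax.scan(env.step, state, actions)


# Because env is a pytree, it can be a (traced) argument of a jitted function.
@jax.jit
def rollout(env, state, actions):
    return jax.lax.scan(env.step, state, actions)


_, jit_infos = rollout(env, state, actions)
```

Passing env as an argument (instead of closing over it) means changing goal or max_steps does not require writing a new function, which is the practical payoff of environments being pytrees.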

🌍 Simple, expressive interaction!

  • Environments are pytrees. Squish them through JAX transformations and trace their parameters.
  • Idiomatic jax-y interface of init(key: Key) -> State, Info and step(state: State, action: PyTree) -> State, Info. You can directly jax.lax.scan over step(...)!
  • Spaces are super simple. No Tuple, Dict nonsense! There are two spaces: Continuous and Discrete, which you can compose into a PyTreeSpace.
  • Explicit episode truncation lets you handle bootstrapping correctly when computing value-function targets.
  • No auto-reset by default. Auto-resetting typically computes a reset on every step, which can be expensive! Opt in with AutoResetWrapper when you need it.
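Why a separate truncation flag matters: here is a minimal sketch of one-step TD targets, assuming per-step terminated and truncated boolean arrays (illustrative names, not necessarily envelope's Info fields). Only true termination stops bootstrapping; a step cut off by a time limit still bootstraps from V(s'), because the underlying MDP did not end there.

```python
import jax.numpy as jnp


def td_targets(rewards, next_values, terminated, gamma=0.99):
    """One-step TD targets r_t + gamma * (1 - terminated_t) * V(s_{t+1}).

    Only termination zeroes the bootstrap term; truncation is deliberately
    not an argument, because a truncated step should still bootstrap.
    """
    not_terminal = 1.0 - terminated.astype(rewards.dtype)
    return rewards + gamma * not_terminal * next_values


# Illustrative transitions: step 1 terminates, step 2 is cut off by a time limit.
rewards = jnp.array([1.0, 1.0, 1.0])
next_values = jnp.array([10.0, 10.0, 10.0])
terminated = jnp.array([False, True, False])
truncated = jnp.array([False, False, True])

correct = td_targets(rewards, next_values, terminated, gamma=0.9)
# Conflating truncation with termination wrongly drops the bootstrap at step 2.
conflated = td_targets(rewards, next_values, terminated | truncated, gamma=0.9)
```

With a single done flag the two cases are indistinguishable, which is exactly the information an explicit truncation signal preserves.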

💪 Powerful, composable wrappers!

  • Carry state across episodes to track running statistics, for example to normalize observations.
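  • Wrappers compose and can be stacked in any order, and the order matters. For example, applying ObservationNormalizationWrapper before VmapWrapper gives per-environment normalization, while applying it after VmapWrapper (as in the example above) gives global normalization.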
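To make the per-env vs. global distinction concrete, here is a minimal sketch of running mean/variance statistics in plain JAX. RunningStats, init_stats, update, and normalize are hypothetical helpers, not the internals of envelope's ObservationNormalizationWrapper. A normalizer outside the vmap sees the whole batch and keeps one set of statistics; one inside the vmap keeps a separate set per environment.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RunningStats(NamedTuple):
    # Hypothetical running mean/variance accumulator (not envelope's API).
    count: jax.Array
    mean: jax.Array
    var: jax.Array


def init_stats(shape):
    return RunningStats(count=jnp.zeros(()), mean=jnp.zeros(shape), var=jnp.ones(shape))


def update(stats, batch):
    """Merge a batch of shape (n, *shape) into stats (Chan et al. parallel update)."""
    n = batch.shape[0]
    b_mean = batch.mean(axis=0)
    b_var = batch.var(axis=0)
    total = stats.count + n
    delta = b_mean - stats.mean
    mean = stats.mean + delta * n / total
    m2 = stats.var * stats.count + b_var * n + delta**2 * stats.count * n / total
    return RunningStats(count=total, mean=mean, var=m2 / total)


def normalize(stats, obs, eps=1e-8):
    return (obs - stats.mean) / jnp.sqrt(stats.var + eps)


num_envs, dim = 4, 3
key = jax.random.PRNGKey(0)
# Two steps of observations; each env gets a different offset so the stats differ.
offsets = jnp.arange(num_envs, dtype=jnp.float32)[:, None]
obs_steps = jax.random.normal(key, (2, num_envs, dim)) + offsets

# Global: the normalizer sits outside the vmap and sees the whole batch.
g = init_stats((dim,))
for obs in obs_steps:
    g = update(g, obs)

# Per-env: the normalizer sits inside the vmap; each env keeps its own stats.
p = jax.tree_util.tree_map(
    lambda x: jnp.broadcast_to(x, (num_envs,) + x.shape), init_stats((dim,)))
per_env_update = jax.vmap(lambda s, o: update(s, o[None]))
for obs in obs_steps:
    p = per_env_update(p, obs)

normalized_global = normalize(g, obs_steps[-1])            # one mean/var for all envs
normalized_per_env = jax.vmap(normalize)(p, obs_steps[-1])  # one mean/var per env
```

The global statistics have shape (dim,), while the per-env statistics carry an extra leading (num_envs,) axis; that extra axis is the whole difference between the two wrapper orders.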

🔌 Adapters for existing suites

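| 📦 Suite | 🤖 Agents | 🌍 Environments |
|---|---|---|
| brax | 🕺 | 12 |
| craftax | 🕺 | 4 |
| gymnax | 🕺 | 24 |
| jumanji | 🕺 / 👯 | 25 / 1 |
| kinetix | 🕺 | 4 |
| mujoco_playground | 🕺 | 54 |
| navix | 🕺 | 41 |
| Total | 🕺 / 👯 | 164 / 1 |

🕺 = single-agent, 👯 = multi-agent.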
envelope.create("📦::🌍") lets you create environments from any of the above, e.g. envelope.create("gymnax::CartPole-v1")!

📝 Testing

  • Default (no optional adapter dependencies required): uv run pytest -m "not adapters"
  • Adapter tests (require the full adapters dependency group):
    • uv sync --group adapters
    • uv run pytest -m adapters
    • If any adapter dependency is missing or broken, the run fails fast with an error telling you what to install.

🏗️ Installation

pip install jax-envelope

💞 Related projects

  • stoa is a very similar project that provides adapters and wrappers for the jumanji-like interface.
  • Check out all the great suites we have adapters for! gymnax, brax, jumanji, kinetix, craftax, navix, mujoco_playground.
  • We will be adding support for jaxmarl and pgx in the future, as soon as we have figured out the best ever MARL interface for JAX!
