Refactor pybullet-related code into separate file
dtch1997 committed Oct 2, 2021
1 parent d2800b8 commit 29d1828
Showing 2 changed files with 249 additions and 30 deletions.
30 changes: 0 additions & 30 deletions mbrl/util/mujoco.py
@@ -7,16 +7,12 @@
import gym
import gym.wrappers
import numpy as np
from numpy.lib.function_base import copy
import omegaconf
import torch

import mbrl.planning
import mbrl.types

from pybulletgym.envs.roboschool.robots.locomotors.walker_base import WalkerBase as RSWalkerBase
from pybulletgym.envs.mujoco.robots.locomotors.walker_base import WalkerBase as MJWalkerBase


def make_env(
cfg: Union[omegaconf.ListConfig, omegaconf.DictConfig],
@@ -215,9 +211,6 @@ def __init__(self, env: gym.wrappers.TimeLimit):
elif "mbrl.third_party.dmc2gym" in self._env.env.__class__.__module__:
self._enter_method = self._enter_dmcontrol
self._exit_method = self._exit_dmcontrol
elif "pybulletgym" in self._env.env.__class__.__module__:
self._enter_method = self._enter_pybullet
self._exit_method = self._exit_pybullet
else:
raise RuntimeError("Tried to freeze an unsupported environment.")

@@ -243,29 +236,6 @@ def _exit_dmcontrol(self):
self._env._elapsed_steps = self._elapsed_steps
self._env.env._env._step_count = self._step_count

def _enter_pybullet(self):
env = self._env.env
robot = env.robot
assert isinstance(robot, (RSWalkerBase, MJWalkerBase))
self.state_id = env._p.saveState()
self.ground_ids = env.ground_ids
self.potential = env.potential
self.reward = float(env.reward)
robot_keys = [("body_rpy", tuple), ("body_xyz", tuple), ("feet_contact", np.copy), ("initial_z", float), ("joint_speeds", np.copy), ("joints_at_limit", int), ("walk_target_dist", float), ("walk_target_theta", float), ("walk_target_x", float), ("walk_target_y", float)]

self.robot_data = {}
for k, t in robot_keys:
self.robot_data[k] = t(getattr(robot, k))

def _exit_pybullet(self):
env = self._env.env
env.ground_ids = self.ground_ids
env.potential = self.potential
env.reward = self.reward
env._p.restoreState(self.state_id)
for k, v in self.robot_data.items():
setattr(env.robot, k, v)

def __enter__(self):
return self._enter_method()

249 changes: 249 additions & 0 deletions mbrl/util/pybullet.py
@@ -0,0 +1,249 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple, Union, cast

import gym
import gym.wrappers
import pybulletgym  # noqa: F401 -- imported for its side effect of registering PyBullet envs with gym
import numpy as np
import omegaconf
import torch

import mbrl.planning
import mbrl.types

def make_env(
cfg: Union[omegaconf.ListConfig, omegaconf.DictConfig],
) -> Tuple[gym.Env, mbrl.types.TermFnType, Optional[mbrl.types.RewardFnType]]:
"""Creates an environment from a given OmegaConf configuration object.
This method expects the configuration, ``cfg``,
to have the following attributes (some are optional):
- ``cfg.overrides.env``: the string description of the environment.
Valid options are:
- "gym___<env_name>": a Pybullet Roboschool env name described here: https://github.com/benelot/pybullet-gym
- ``cfg.overrides.term_fn``: (only for dmcontrol and gym environments) a string
indicating the environment's termination function to use when simulating the
environment with the model. It should correspond to the name of a function in
:mod:`mbrl.env.termination_fns`.
- ``cfg.overrides.reward_fn``: (only for dmcontrol and gym environments)
a string indicating the environment's reward function to use when simulating the
environment with the model. If not present, it will try to use ``cfg.overrides.term_fn``.
If that's not present either, it will return a ``None`` reward function.
If provided, it should correspond to the name of a function in
:mod:`mbrl.env.reward_fns`.
- ``cfg.overrides.learned_rewards``: (optional) if present, indicates that
the reward function will be learned, in which case the method will return
a ``None`` reward function.
- ``cfg.overrides.trial_length``: (optional) if present, indicates the maximum length
of trials. Defaults to 1000.
Args:
cfg (omegaconf.DictConf): the configuration to use.
Returns:
(tuple of env, termination function, reward function): returns the new environment,
the termination function to use, and the reward function to use (or ``None`` if
``cfg.learned_rewards == True``).
"""
if "gym___" in cfg.overrides.env:
import mbrl.env
env = gym.make(cfg.overrides.env.split("___")[1])
term_fn = getattr(mbrl.env.termination_fns, cfg.overrides.term_fn)
if hasattr(cfg.overrides, "reward_fn") and cfg.overrides.reward_fn is not None:
reward_fn = getattr(mbrl.env.reward_fns, cfg.overrides.reward_fn)
else:
reward_fn = getattr(mbrl.env.reward_fns, cfg.overrides.term_fn, None)
else:
raise ValueError("Invalid environment string.")

learned_rewards = cfg.overrides.get("learned_rewards", True)
if learned_rewards:
reward_fn = None

if cfg.seed is not None:
env.seed(cfg.seed)
env.observation_space.seed(cfg.seed + 1)
env.action_space.seed(cfg.seed + 2)

return env, term_fn, reward_fn
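
# A minimal sketch of a config that ``make_env`` above would accept. The env id,
# term_fn name, and seed are illustrative assumptions, not values required by
# this module:
#
#   cfg = omegaconf.OmegaConf.create(
#       {
#           "overrides": {
#               "env": "gym___HalfCheetahPyBulletEnv-v0",
#               "term_fn": "no_termination",
#               "learned_rewards": True,
#           },
#           "seed": 0,
#       }
#   )
#   env, term_fn, reward_fn = make_env(cfg)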

def make_env_from_str(env_name: str) -> gym.Env:
"""Creates a new environment from its string description.
Args:
env_name (str): the string description of the environment. Valid options are:
- "gym___<env_name>": a Gym environment (e.g., "gym___HalfCheetah-v2").
Returns:
(gym.Env): the created environment.
"""
if "gym___" in env_name:
env = gym.make(env_name.split("___")[1])
else:
raise ValueError("Invalid environment string.")

return env
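
# Example (hypothetical; assumes pybullet-gym is installed, which registers this env id):
#   env = make_env_from_str("gym___HalfCheetahPyBulletEnv-v0")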

def _is_pybullet_gym_env(env: gym.wrappers.TimeLimit) -> bool:
    # Same check previously used in mbrl/util/mujoco.py: pybullet-gym
    # environments are recognizable from the module of the wrapped class.
    return "pybulletgym" in env.env.__class__.__module__

class freeze_pybullet_env:
"""Provides a context to freeze a Pybullet environment.
This context allows the user to manipulate the state of a PyBullet environment and return it
to its original state upon exiting the context.
Works with pybulletgym environments
Example usage:
.. code-block:: python
env = gym.make("HalfCheetahPyBulletEnv-v0")
env.reset()
action = env.action_space.sample()
# o1_expected, *_ = env.step(action)
with freeze_pybullet_env(env):
step_the_env_a_bunch_of_times()
o1, *_ = env.step(action) # o1 will be equal to what o1_expected would have been
Args:
env (:class:`gym.wrappers.TimeLimit`): the environment to freeze.
"""

def __init__(self, env: gym.wrappers.TimeLimit):
self._env = env
        self._init_state: Optional[np.ndarray] = None
self._elapsed_steps = 0
self._step_count = 0

if _is_pybullet_gym_env(env):
            self._enter_method = self._enter_pybullet
            self._exit_method = self._exit_pybullet
else:
raise RuntimeError("Tried to freeze an unsupported environment.")

def _enter_pybullet(self):
# For now, the accepted envs are limited to ease implementation and testing
from pybulletgym.envs.roboschool.robots.locomotors.walker_base import WalkerBase as RSWalkerBase
from pybulletgym.envs.mujoco.robots.locomotors.walker_base import WalkerBase as MJWalkerBase
env = self._env.env
robot = env.robot
assert isinstance(robot, (RSWalkerBase, MJWalkerBase))
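        # pybullet's saveState() snapshots the full simulation in memory and
        # returns an id that restoreState() accepts when the context exits.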
self.state_id = env._p.saveState()
self.ground_ids = env.ground_ids
self.potential = env.potential
self.reward = float(env.reward)
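        # Bookkeeping attributes of WalkerBase that stepping mutates; each is
        # copied through a constructor (tuple/np.copy/float/int) so the saved
        # values are not aliased to the live robot.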
        robot_keys = [
            ("body_rpy", tuple),
            ("body_xyz", tuple),
            ("feet_contact", np.copy),
            ("initial_z", float),
            ("joint_speeds", np.copy),
            ("joints_at_limit", int),
            ("walk_target_dist", float),
            ("walk_target_theta", float),
            ("walk_target_x", float),
            ("walk_target_y", float),
        ]

self.robot_data = {}
for k, t in robot_keys:
self.robot_data[k] = t(getattr(robot, k))

def _exit_pybullet(self):
env = self._env.env
env.ground_ids = self.ground_ids
env.potential = self.potential
env.reward = self.reward
env._p.restoreState(self.state_id)
for k, v in self.robot_data.items():
setattr(env.robot, k, v)

def __enter__(self):
return self._enter_method()

def __exit__(self, *_args):
return self._exit_method()

def get_current_state(env: gym.wrappers.TimeLimit) -> Tuple:
"""Returns the internal state of the environment.
Returns a tuple with information that can be passed to :func:set_env_state` to manually
set the environment (or a copy of it) to the same state it had when this function was called.
    Works with pybullet-gym environments.
Args:
env (:class:`gym.wrappers.TimeLimit`): the environment.
"""
# TODO: Figure out what this should return
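    # One plausible sketch (an assumption, not tested here): mirror
    # freeze_pybullet_env and return env._p.saveState() plus copies of the
    # robot attributes, so that set_env_state below can restore them.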
if _is_pybullet_gym_env(env):
return ()
else:
raise NotImplementedError(
"Only pybulletgym environments supported."
)


def set_env_state(state: Tuple, env: gym.wrappers.TimeLimit):
"""Sets the state of the environment.
Assumes ``state`` was generated using :func:`get_current_state`.
    Works with pybullet-gym environments.
Args:
state (tuple): see :func:`get_current_state` for a description.
env (:class:`gym.wrappers.TimeLimit`): the environment.
"""
# TODO: Figure out what this should do
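    # One plausible sketch (an assumption, not tested here): unpack the tuple
    # produced by get_current_state, call env._p.restoreState(state_id), and
    # setattr the saved robot attributes back, as _exit_pybullet does above.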
if _is_pybullet_gym_env(env):
pass
else:
raise NotImplementedError(
"Only pybulletgym environments supported."
)

def rollout_pybullet_env(
env: gym.wrappers.TimeLimit,
initial_obs: np.ndarray,
lookahead: int,
agent: Optional[mbrl.planning.Agent] = None,
plan: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Runs the environment for some number of steps then returns it to its original state.
    Works with pybullet-gym environments.
Args:
env (:class:`gym.wrappers.TimeLimit`): the environment.
initial_obs (np.ndarray): the latest observation returned by the environment (only
needed when ``agent is not None``, to get the first action).
lookahead (int): the number of steps to run. If ``plan is not None``,
it is overridden by `len(plan)`.
agent (:class:`mbrl.planning.Agent`, optional): if given, an agent to obtain actions.
plan (sequence of np.ndarray, optional): if given, a sequence of actions to execute.
Takes precedence over ``agent`` when both are given.
Returns:
(tuple of np.ndarray): the observations, rewards, and actions observed, respectively.
"""
actions = []
real_obses = []
rewards = []
with freeze_pybullet_env(cast(gym.wrappers.TimeLimit, env)):
current_obs = initial_obs.copy()
real_obses.append(current_obs)
if plan is not None:
lookahead = len(plan)
for i in range(lookahead):
a = plan[i] if plan is not None else agent.act(current_obs)
if isinstance(a, torch.Tensor):
a = a.numpy()
next_obs, reward, done, _ = env.step(a)
actions.append(a)
real_obses.append(next_obs)
rewards.append(reward)
if done:
break
current_obs = next_obs
return np.stack(real_obses), np.stack(rewards), np.stack(actions)
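
# Example usage (hypothetical; ``agent`` can be any mbrl.planning.Agent):
#   obs = env.reset()
#   obses, rewards, actions = rollout_pybullet_env(env, obs, lookahead=10, agent=agent)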
