Skip to content

Commit

Permalink
Enable planning with PyBullet envs
Browse files Browse the repository at this point in the history
  • Loading branch information
dtch1997 committed Oct 18, 2021
1 parent de2257a commit d6cc263
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
1 change: 1 addition & 0 deletions mbrl/diagnostics/control_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def init(env_name: str, seed: int):
handler__ = mbrl.util.create_handler_from_str(env_name)
env__ = handler__.make_env_from_str(env_name)
env__.seed(seed)
env__.reset()


def step_env(action: np.ndarray):
Expand Down
19 changes: 15 additions & 4 deletions mbrl/util/pybullet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import tempfile
from typing import Tuple

import gym
Expand Down Expand Up @@ -112,6 +113,16 @@ def get_current_state(env: gym.wrappers.TimeLimit) -> Tuple:
else:
raise RuntimeError("Only pybulletgym environments supported.")

@staticmethod
def save_state_to_file(p) -> str:
    """Saves the simulation state to a fresh temporary ``.bullet`` file.

    The path is reserved with ``delete=False`` so that the file is not removed
    when the temporary file object is finalized. Relying on the finalizer (as
    ``NamedTemporaryFile(...).name`` does) either frees the name before the state
    is written, or, on interpreters with deferred finalization, may delete the
    file after the state has already been written to it.

    Args:
        p: object exposing ``saveBullet(path)``; callers in this file pass ``env._p``.

    Returns:
        (str): path of the file the state was written to. The file is not removed
        here; the caller is responsible for deleting it once it is no longer needed.
    """
    # Close our own handle before handing the path over, so that the writer
    # can open the file itself (required on platforms that forbid double-opens).
    with tempfile.NamedTemporaryFile(suffix=".bullet", delete=False) as handle:
        bulletfile = handle.name
    p.saveBullet(bulletfile)
    return bulletfile

@staticmethod
def load_state_from_file(p, filename: str) -> None:
    """Restores a simulation state from a file on disk.

    Args:
        p: object exposing ``restoreState``; callers in this file pass ``env._p``.
        filename (str): path forwarded to ``restoreState`` as its ``fileName``
            keyword argument.
    """
    restore_kwargs = {"fileName": filename}
    p.restoreState(**restore_kwargs)

@staticmethod
def _get_current_state_locomotion(env: gym.wrappers.TimeLimit) -> Tuple:
"""Returns the internal state of the environment.
Expand All @@ -128,7 +139,7 @@ def _get_current_state_locomotion(env: gym.wrappers.TimeLimit) -> Tuple:
if not isinstance(robot, (RSWalkerBase, MJWalkerBase)):
raise RuntimeError("Invalid robot type. Expected a locomotor robot")

state_id = env._p.saveState()
filename = PybulletEnvHandler.save_state_to_file(env._p)
ground_ids = env.ground_ids
potential = env.potential
reward = float(env.reward)
Expand All @@ -150,7 +161,7 @@ def _get_current_state_locomotion(env: gym.wrappers.TimeLimit) -> Tuple:
robot_data[k] = t(getattr(robot, k))

return (
state_id,
filename,
ground_ids,
potential,
reward,
Expand Down Expand Up @@ -197,7 +208,7 @@ def _set_env_state_locomotion(state: Tuple, env: gym.wrappers.TimeLimit):
"""
if _is_pybullet_gym_env(env):
(
state_id,
filename,
ground_ids,
potential,
reward,
Expand All @@ -208,7 +219,7 @@ def _set_env_state_locomotion(state: Tuple, env: gym.wrappers.TimeLimit):
env.ground_ids = ground_ids
env.potential = potential
env.reward = reward
env._p.restoreState(state_id)
PybulletEnvHandler.load_state_from_file(env._p, filename)
for k, v in robot_data.items():
setattr(env.robot, k, v)
else:
Expand Down

0 comments on commit d6cc263

Please sign in to comment.