Skip to content

Commit

Permalink
restore mjx example
Browse files Browse the repository at this point in the history
  • Loading branch information
thowell committed Feb 26, 2024
1 parent 68619b1 commit 3434287
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 18 deletions.
30 changes: 18 additions & 12 deletions python/mujoco_mpc/mjx/tasks/bimanual/handover.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@
# limitations under the License.
# ==============================================================================

from etils import epath
import jax
from jax import numpy as jp
import mujoco
from mujoco import mjx
import os
import pathlib
from typing import Callable
from mujoco_mpc.mjx import predictive_sampling

CostFn = Callable[[mjx.Model, mjx.Data], jax.Array]

def bring_to_target(m: mjx.Model, d: mjx.Data) -> jax.Array:
"""Returns cost for bimanual bring to target task."""
Expand Down Expand Up @@ -50,14 +48,22 @@ def bring_to_target(m: mjx.Model, d: mjx.Data) -> jax.Array:


def get_models_and_cost_fn() -> (
tuple[mujoco.MjModel, mujoco.MjModel, CostFn]
tuple[mujoco.MjModel, mujoco.MjModel, predictive_sampling.CostFn]
):
"""Returns a tuple of the model and the cost function."""
path = (
pathlib.Path(os.path.abspath("")).parent.parent.parent
/ "build/mjpc/tasks/bimanual/mjx_scene.xml"
)
sim_model = mujoco.MjModel.from_xml_path(str(path))
plan_model = mujoco.MjModel.from_xml_path(str(path))
path = epath.Path(
'build/mjpc/tasks/bimanual/'
)
model_file_name = 'mjx_scene.xml'
xml = (path / model_file_name).read_text()
assets = {}
for f in path.glob('*.xml'):
if f.name == model_file_name:
continue
assets[f.name] = f.read_bytes()
for f in (path / 'assets').glob('*'):
assets[f.name] = f.read_bytes()
sim_model = mujoco.MjModel.from_xml_string(xml, assets)
plan_model = mujoco.MjModel.from_xml_string(xml, assets)
plan_model.opt.timestep = 0.01 # incidentally, already the case
return sim_model, plan_model, bring_to_target
return sim_model, plan_model, bring_to_target
12 changes: 6 additions & 6 deletions python/mujoco_mpc/mjx/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# %%

import matplotlib.pyplot as plt
import mediapy
import mujoco
import predictive_sampling
from tasks.bimanual import handover
from mujoco_mpc.mjx import predictive_sampling
from mujoco_mpc.mjx.tasks.bimanual import handover
# %%
nsteps = 1
nsteps = 500
steps_per_plan = 4
frame_skip = 5 # how many steps between each rendered frame

Expand All @@ -34,7 +34,7 @@
nsample=128 - 1,
interp='zero',
)
# %%

trajectory, costs, plan_time = (
predictive_sampling.receding_horizon_optimization(
p,
Expand Down Expand Up @@ -70,4 +70,4 @@
frames.append(renderer.render())
# %%
mediapy.show_video(frames, fps=1/sim_model.opt.timestep/frame_skip)
# %%
# %%

0 comments on commit 3434287

Please sign in to comment.