Skip to content

Conversation

@kevinzakka
Copy link
Collaborator

@kevinzakka kevinzakka commented Dec 11, 2025

Adds set_const for recomputing model fields after runtime parameter changes.

  • set_const(m, d): full recomputation
  • set_const_fixed(m, d): fixed quantities only (body_subtreemass, ngravcomp)
  • set_const_0(m, d): qpos0-dependent quantities only

There are differences from the C mujoco implementation, namely that we skip fields that don't exist in mjwarp: dof_M0, actuator_length0, tree system fields (body_treeid, tree_*), stats (stat.center, stat.extent, stat.meansize).

Add 11 tests covering: body_subtreemass accumulation, ngravcomp counting, invweight computations (free/ball/hinge joints, static bodies), camera/light positions, actuator_acc0, qpos preservation, idempotency, and full pipeline validation against the C MuJoCo implementation using the humanoid model.

body_subtreemass_io[body_subtreemass_id, parentid] += body_subtreemass_io[body_subtreemass_id, bodyid]

wp.launch(init_subtreemass, dim=(d.nworld, m.nbody), inputs=[m.body_mass], outputs=[m.body_subtreemass])
for bodyid in range(m.nbody - 1, 0, -1):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of iterating over bodies serially i think we can leverage tree-level parallelism. see com_pos for a design pattern that should be relevant here for reverse tree traversal with tree-level parallelism.

- dof_invweight0: inverse inertia for DOFs
- body_invweight0: inverse spatial inertia for bodies
- tendon_invweight0: inverse weight for tendons
- cam_pos0, cam_poscom0, cam_mat0: camera reference positions
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'cam_mat0' is not a position, so maybe 'camera reference positions' -> 'camera references'?

- body_invweight0: inverse spatial inertia for bodies
- tendon_invweight0: inverse weight for tendons
- cam_pos0, cam_poscom0, cam_mat0: camera reference positions
- light_pos0, light_poscom0, light_dir0: light reference positions
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'light_dir0' is not a position, so maybe 'light reference positions' -> 'light references'?

qpos0: wp.array2d(dtype=float),
qpos_out: wp.array2d(dtype=float),
):
worldid = wp.tid()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can launch this kernel over the configuration dimension nq instead of looping below

if jtype == int(types.JointType.FREE.value):
# FREE joint: 6 DOFs, average first 3 (trans) and last 3 (rot) separately
if dofid < dofadr + 3:
avg = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wp.static(1.0 / 3.0) * should be a bit more performant


body_invweight0_out[body_invweight0_id, bodyid] = wp.vec2(inv_trans, inv_rot)

for bodyid in range(1, m.nbody):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can probably do something more performant than nested iterations. will you please add a todo?

)
else:

@nested_kernel(module="unique", enable_backward=False)
Copy link
Collaborator

@thowell thowell Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since this kernel just zeros all elements in an array we can utilize m.body_invweight0.zero_()

# body_invweight0: computed as mean diagonal of J * inv(M) * J'
# where J is the 6xnv body Jacobian (3 rows translation, 3 rows rotation)
if m.nv > 0:
body_jac_row = wp.zeros((d.nworld, m.nv), dtype=float)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if all of the elements in these arrays are initialized by the kernels below, wp.empty should be a bit more performant compared to wp.zeros.

ten_J_in: wp.array3d(dtype=float),
ten_J_vec_out: wp.array2d(dtype=float),
):
worldid = wp.tid()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should be able to launch this kernel over the dof dimension nv


# tendon_invweight0[t] = J_t * inv(M) * J_t'
if m.ntendon > 0:
ten_J_vec = wp.zeros((d.nworld, m.nv), dtype=float)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if all the elements in these arrays are computed by the kernels below it should be a bit more performance to utilize wp.empty instead of wp.zeros

tendon_length0_id = worldid % tendon_length0_out.shape[0]
tendon_length0_out[tendon_length0_id, tenid] = ten_length_in[worldid, tenid]

if m.ntendon > 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can remove this check on ntendon since warp will not launch the kernel if ntendon == 0

cam_poscom0_out[cam_pos0_id, camid] = cam_xpos - subtree_com_in[worldid, bodyid]
cam_mat0_out[cam_pos0_id, camid] = cam_xmat_in[worldid, camid]

if m.ncam > 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can remove the ncam check

light_poscom0_out[light_pos0_id, lightid] = light_xpos - subtree_com_in[worldid, bodyid]
light_dir0_out[light_pos0_id, lightid] = light_xdir_in[worldid, lightid]

if m.nlight > 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can remove the nlight check


# actuator_acc0[i] = ||inv(M) * actuator_moment[i]|| - acceleration from unit actuator force
if m.nu > 0 and m.nv > 0:
act_moment_vec = wp.zeros((d.nworld, m.nv), dtype=float)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if all of the array elements are initialized by kernels below we can utilize wp.empty

actuator_moment_in: wp.array3d(dtype=float),
act_moment_vec_out: wp.array2d(dtype=float),
):
worldid = wp.tid()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can launch over the dof dimension nv

eq_data | For connect/weld, offsets computed if not set.
hfield_size |
tendon_stiffness, tendon_damping | Only if changing from/to zero.
actuator_gainprm, actuator_biasprm | For position actuators with dampratio.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we align the vertical divider?

Modified Field | Call
------------------|------------------
body_mass | set_const (both)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does (both) mean?

mjm = mujoco.MjModel.from_xml_string("""
<mujoco>
<worldbody>
<body name="link1" pos="0 0 0">
Copy link
Collaborator

@thowell thowell Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's remove pos="0 0 0" since this is the default

mujoco.mj_forward(mjm, mjd)

m = mjwarp.put_model(mjm)
d = mjwarp.put_data(mjm, mjd)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets utilize test_data.fixture

mujoco.mj_setConst(mjm, mjd)
mjwarp.set_const(m, d)

np.testing.assert_allclose(m.dof_invweight0.numpy()[0], mjm.dof_invweight0, rtol=1e-3)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets utilize _assert_eq

mjm = mujoco.MjModel.from_xml_string("""
<mujoco>
<worldbody>
<body name="link1" pos="0 0 0">
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets remove pos="0 0 0"

mujoco.mj_forward(mjm, mjd)

m = mjwarp.put_model(mjm)
d = mjwarp.put_data(mjm, mjd)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets utilize test_data.fixture

mjwarp.set_const(m, d)

np.testing.assert_allclose(m.dof_invweight0.numpy()[0], mjm.dof_invweight0, rtol=1e-3)
np.testing.assert_allclose(m.body_invweight0.numpy()[0, 1], mjm.body_invweight0[1], rtol=1e-3)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_assert_eq

mujoco.mj_forward(mjm, mjd)

m = mjwarp.put_model(mjm)
d = mjwarp.put_data(mjm, mjd)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_data.fixture

mujoco.mj_setConst(mjm, mjd)
mjwarp.set_const(m, d)

np.testing.assert_allclose(m.dof_invweight0.numpy()[0], mjm.dof_invweight0, rtol=1e-3)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_assert_eq

mujoco.mj_forward(mjm, mjd)

m = mjwarp.put_model(mjm)
d = mjwarp.put_data(mjm, mjd)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_data.fixture

mujoco.mj_setConst(mjm, mjd)
mjwarp.set_const(m, d)

np.testing.assert_allclose(m.body_invweight0.numpy()[0, 1], [0.0, 0.0], atol=1e-6)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_assert_eq

mjm = mujoco.MjModel.from_xml_string("""
<mujoco>
<worldbody>
<body name="mass" pos="0 0 0">
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets remove pos="0 0 0"

</mujoco>
""")
mjd = mujoco.MjData(mjm)
mjd.qpos[0] = 0.5
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets add a keyframe and then utilize test_data.fixture

mjm = mujoco.MjModel.from_xml_string("""
<mujoco>
<worldbody>
<body name="root" pos="0 0 0">
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets remove pos="0 0 0"

mujoco.mj_forward(mjm, mjd)

m = mjwarp.put_model(mjm)
d = mjwarp.put_data(mjm, mjd)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_data.fixture

_assert_eq(m.body_subtreemass.numpy()[0], mjm.body_subtreemass, "body_subtreemass")

# Verify: root=10+(20+30)+40=100, child1=20+30=50, grandchild1=30, child2=40
np.testing.assert_allclose(m.body_subtreemass.numpy()[0, 1], 100.0, rtol=1e-6)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_assert_eq

mjm = mujoco.MjModel.from_xml_string("""
<mujoco>
<worldbody>
<body name="body1" pos="0 0 0" gravcomp="1">
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets remove pos="0 0 0"

mujoco.mj_forward(mjm, mjd)

m = mjwarp.put_model(mjm)
d = mjwarp.put_data(mjm, mjd)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_data.fixture

mujoco.mj_forward(mjm, mjd)

m = mjwarp.put_model(mjm)
d = mjwarp.put_data(mjm, mjd)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_data.fixture

mujoco.mj_setConst(mjm, mjd)
mjwarp.set_const(m, d)

np.testing.assert_allclose(m.cam_pos0.numpy()[0, 0], mjm.cam_pos0[0], rtol=1e-5)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_assert_eq

mjm = mujoco.MjModel.from_xml_string("""
<mujoco>
<worldbody>
<body name="link1" pos="0 0 0">
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets remove pos="0 0 0"

mujoco.mj_forward(mjm, mjd)

m = mjwarp.put_model(mjm)
d = mjwarp.put_data(mjm, mjd)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_data.fixture

qpos0_np[0, 7:11] = [0.9, 0.1, 0.1, 0.1]
qpos0_np[0, 11] = 0.5
qpos0_np[0, 12] = 0.3
wp.copy(m.qpos0, wp.array(qpos0_np, dtype=m.qpos0.dtype))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets add a keyframe and utilize test_data.fixture

mjwarp.set_const(m, d)

_assert_eq(m.body_subtreemass.numpy()[0], mjm.body_subtreemass, "body_subtreemass")
np.testing.assert_allclose(m.dof_invweight0.numpy()[0], mjm.dof_invweight0, rtol=1e-3)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_assert_eq

Copy link
Collaborator

@thowell thowell left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome contribution @kevinzakka!!

left some comments throughout the code. please let us know if you have any questions. thanks!

@thowell
Copy link
Collaborator

thowell commented Dec 12, 2025

it would be great to add a graph capture test. see here for an example.

@thowell
Copy link
Collaborator

thowell commented Jan 5, 2026

@kevinzakka some of the comments above are resolved by this pr

d: The data object containing the current state and output arrays (device).
"""

@nested_kernel(module="unique", enable_backward=False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

init_subtreemass and accumulate_subtreemass should be moved outside of the set_const_fixed scope and decorated with @wp.kernel since they are purely functions of the inputs/outputs (eg, no dependency on Model or Data fields). also, lets prepend the function names with _.

we reserve utilizing @nested kernel to create kernels that utilize static variables (eg, dependency on Model or Data fields)

"""
qpos_saved = wp.clone(d.qpos)

@nested_kernel(module="unique", enable_backward=False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similar to the comments above we should move this kernel outside of the host function scope and change the decorator to @wp.kernel

smooth.factor_m(m, d)
smooth.transmission(m, d)

@nested_kernel(module="unique", enable_backward=False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same things as copy_qpos0_qpos


@nested_kernel(module="unique", enable_backward=False)
def compute_body_jac_row(
nv: int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for this kernel, we can:

  1. move outside the host function scope and change the decorator
  2. remove nv as an argument and utilize wp.static(m.nv) in the kernel. if we go with this option, lets add a todo to make have a 'compute_body_jac_rowbuilder that takesbodyid_targetandrow_idx` as argument (these would be evaluated statically in the kernel)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants