Commit 83b7daa: fixes
thowell committed Feb 14, 2024
1 parent 6da5f8b

Showing 1 changed file with 34 additions and 17 deletions.
python/mujoco_mpc/demos/direct/direct_optimizer.py: 34 additions & 17 deletions
@@ -607,15 +607,15 @@ def diff_cost_force(

     if nparam > 0:
       # gradient
-      grad[ndq:] += dfdp[t].T @ norm_grad
+      grad[ndq:] += (dfdp[t].T @ norm_grad).ravel()

       # dense rows
       dpdq012[:, (t - 1) * model.nv] += dfdp[t].T @ norm_hess @ dfdq012[t]

   # set dense rows in Hessian
   # TODO(taylor): confirm ravel
   if nparam > 0:
-    hess[ndq:, :] = dpdq012.ravel()
+    hess[ndq:, :] = dpdq012.reshape((nparam, 3 * model.nv))

   return grad, hess
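The ravel/reshape change in this hunk matters once the model has more than one parameter. A minimal NumPy sketch of the shapes involved (hypothetical sizes, reduced to a single time step, standalone stand-ins for the arrays above):

    import numpy as np

    nv, nparam, nf = 3, 2, 4                  # hypothetical sizes: DOFs, parameters, residual dim
    ndq = 3 * nv                              # width of the (q_{t-1}, q_t, q_{t+1}) block

    grad = np.zeros(ndq + nparam)             # gradient over [configurations, parameters]
    hess = np.zeros((ndq + nparam, 3 * nv))   # banded rows plus dense parameter rows

    dfdp = np.random.randn(nf, nparam)        # d(residual)/d(parameters)
    dfdq012 = np.random.randn(nf, 3 * nv)     # d(residual)/d(q_{t-1}, q_t, q_{t+1})
    norm_grad = np.random.randn(nf, 1)        # norm gradient, stored as a column
    norm_hess = np.random.randn(nf, nf)       # norm Hessian

    # dfdp.T @ norm_grad has shape (nparam, 1); ravel flattens it to (nparam,)
    # so it can be accumulated into the 1-D gradient slice.
    grad[ndq:] += (dfdp.T @ norm_grad).ravel()

    # The dense parameter rows have shape (nparam, 3 * nv); reshape keeps them 2-D
    # to match the Hessian slice, whereas ravel would flatten them to 1-D and only
    # broadcast correctly when nparam == 1.
    dpdq012 = dfdp.T @ norm_hess @ dfdq012
    hess[ndq:, :] = dpdq012.reshape((nparam, 3 * nv))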

@@ -768,15 +768,17 @@ def diff_cost_sensor(
       # parameters
       if nparam > 0:
         # gradient
-        grad[ndq:] += dsdp[t][idx, :].T @ normi_grad
+        grad[ndq:] += (dsdp[t][idx, :].T @ normi_grad).ravel()

         # dense row
-        dpdq012[:, (t - 1) * model.nv] += dsdp[t][idx, :].T @ normi_hess @ dsidq012
+        dpdq012[:, (t - 1) * model.nv] += (
+            dsdp[t][idx, :].T @ normi_hess @ dsidq012
+        )

   # set Hessian dense rows
   # TODO(taylor): confirm ravel
   if nparam > 0:
-    hess[ndq:, :] = dpdq012.ravel()
+    hess[ndq:, :] = dpdq012.reshape((nparam, 3 * model.nv))

   return grad, hess
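The same shape fix is applied to the sensor terms; the extra wrinkle here is that idx selects the sensor dimensions entering this norm, so only those rows of the sensor Jacobians contribute. A short sketch with hypothetical shapes:

    import numpy as np

    nv, nparam, nsensordata = 2, 3, 5
    idx = np.array([0, 2, 3])                         # active sensor dimensions (hypothetical)

    dsdp_t = np.random.randn(nsensordata, nparam)     # d(sensor)/d(parameters) at time t
    dsdq012_t = np.random.randn(nsensordata, 3 * nv)  # d(sensor)/d(q_{t-1}, q_t, q_{t+1})
    dsidq012 = dsdq012_t[idx, :]                      # rows for the selected dimensions
    normi_grad = np.random.randn(idx.size, 1)         # norm derivatives on those rows
    normi_hess = np.random.randn(idx.size, idx.size)

    grad_param = (dsdp_t[idx, :].T @ normi_grad).ravel()    # shape (nparam,)
    dense_rows = dsdp_t[idx, :].T @ normi_hess @ dsidq012   # shape (nparam, 3 * nv)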

@@ -911,7 +913,7 @@ class DirectOptimizer:
     _hessian: band representation of cost Hessian wrt decision variables (nv * horizon x 3 * nv).
     _hessian_factor: factorization of band represented cost Hessian
     _search_direction: Gauss-Newton search direction (nv * horizon).
-    _qpos_candidate: candidate search point for configuration trajectory (nq x horizon).
+    _qpos_copy: copy of configuration trajectory (nq x horizon).
     _regularization: current value for cost Hessian regularization.
     _gradient_norm: normalized L2-norm of cost gradient.
     _iterations_step: number of step iterations performed.
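As background for the attributes above (a pure-NumPy sketch with made-up sizes, not the class's internal storage): each residual couples only (q_{t-1}, q_t, q_{t+1}), so the configuration block of the Gauss-Newton Hessian is banded with bandwidth 3 * nv and can be stored and factorized in band form; the parameter rows are the dense exception.

    import numpy as np

    nv, horizon = 2, 5
    ntotal = nv * horizon
    hess = np.zeros((ntotal, ntotal))

    # scatter per-timestep Gauss-Newton blocks over (q_{t-1}, q_t, q_{t+1})
    for t in range(1, horizon - 1):
      J_t = np.random.randn(4, 3 * nv)     # hypothetical residual Jacobian at time t
      i = (t - 1) * nv
      hess[i:i + 3 * nv, i:i + 3 * nv] += J_t.T @ J_t

    # every nonzero entry lies within 3 * nv of the diagonal, matching the
    # (nv * horizon) x (3 * nv) band representation described in the docstring
    rows, cols = np.nonzero(hess)
    assert np.all(np.abs(rows - cols) < 3 * nv)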
@@ -1025,7 +1027,7 @@ def __init__(
     self._search_direction = np.zeros(self._ntotal)

     # candidate qpos
-    self._qpos_candidate = np.zeros((model.nq, horizon))
+    self._qpos_copy = np.zeros((model.nq, horizon))

     # regularization
     self._regularization = 1.0e-12
@@ -1277,13 +1279,19 @@ def _cost_derivatives(
     # parameters
     if self._parameter_flag:
       # gradient
-      self._gradient[self.model.nv * np.sum(self.pinned):] = self.weight_parameter * (self.parameter - self.parameter_target)
+      self._gradient[
+          self.model.nv * np.sum(self.pinned) :
+      ] = self.weight_parameter * (self.parameter - self.parameter_target)

       # Hessian
       # TODO(taylor): improve
       dense = np.zeros((self._num_parameter, self._ntotal_pin))
-      dense[:, self.model.nv * np.sum(self.pinned)] = self.weight_parameter * np.eye(self._num_parameter)
-      self._hessian[self.model.nv * np.sum(self.pinned):, :] = dense.ravel()
+      dense[
+          :, self.model.nv * np.sum(self.pinned)
+      ] = self.weight_parameter * np.eye(self._num_parameter)
+      self._hessian[self.model.nv * np.sum(self.pinned) :, :] = dense.reshape(
+          (self._num_parameter, self._nband)
+      )

   def _eval_search_direction(self) -> bool:
     """Compute search direction.
@@ -1424,6 +1432,11 @@ def optimize(self):
       self._status_msg = "gradient tolerance achieved"
       return

+    # copy variables
+    self._qpos_copy = np.copy(self.qpos)
+    if self._parameter_flag:
+      self._parameter_copy = np.copy(self.parameter)
+
     # search iterations
     candidate_cost = current_cost
     self._improvement = 0.0
@@ -1441,27 +1454,31 @@
       if not self._eval_search_direction():
         return

-      # compute candidate
-      self._qpos_candidate = configuration_update(
+      # compute new variables
+      self.qpos = configuration_update(
           self.model,
-          self.qpos,
+          self._qpos_copy,
          self._search_direction,
           -1.0,
           self.horizon,
           self.pinned,
       )

+      if self._parameter_flag:
+        self.parameter = (
+            self._parameter_copy
+            - 1.0
+            * self._search_direction[(self.model.nv * np.sum(self.pinned)):]
+        )
+
       # candidate cost
-      candidate_cost = self.cost(self._qpos_candidate)
+      candidate_cost = self.cost(self.qpos)
       self._improvement = current_cost - candidate_cost

       # check improvement
       if candidate_cost < current_cost:
         # update cost
         current_cost = candidate_cost
-
-        # update configurations
-        self.qpos = np.array(self._qpos_candidate)
         break
       else:
         # increase regularization
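Taken together, the last two hunks restructure the search loop: the current iterate is copied before the search, each candidate is built by applying the full (negative) step to that copy, and the regularization grows when the candidate fails to reduce the cost. A plain-NumPy sketch of the same pattern on a toy problem (illustrative names, not DirectOptimizer's API):

    import numpy as np

    def cost(x):
      return 0.5 * np.sum((x - 1.0) ** 4)

    def gradient(x):
      return 2.0 * (x - 1.0) ** 3

    def hessian(x):
      return np.diag(6.0 * (x - 1.0) ** 2)

    x = np.array([3.0, -2.0])
    regularization = 1.0e-12
    for _ in range(20):                         # outer "step" iterations
      x_copy = np.copy(x)                       # analogous to _qpos_copy / _parameter_copy
      current_cost = cost(x_copy)
      for _ in range(10):                       # inner "search" iterations
        hess = hessian(x_copy) + regularization * np.eye(x.size)
        step = np.linalg.solve(hess, gradient(x_copy))
        x = x_copy - 1.0 * step                 # candidate built from the saved copy
        if cost(x) < current_cost:
          break                                 # improvement: accept the candidate
        regularization *= 10.0                  # otherwise raise regularization and retry

A full implementation would typically also shrink the regularization after an accepted step; the gradient-tolerance check shown in the previous hunk is one of the termination tests.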
