Skip to content

Commit 5ea6f55

Browse files
committed
refactor(exp/inverse): Simplify activation parameter to a single scalar stretch
The parameter `q` for the inverse problem, representing muscle activation, is changed from a 6-DoF Voigt-notation vector to a single scalar per active cell. This scalar represents the principal stretch factor, gamma. The full activation matrix is now derived from this scalar, assuming volume preservation, with diagonal components `[1/gamma, sqrt(gamma), sqrt(gamma)]`. This significantly reduces the dimensionality of the optimization problem from `6 * n_active_cells` to `n_active_cells`, making it better-posed and easier to solve. Additionally, the maximum number of iterations for the conjugate gradient linear solver in the adjoint method is increased to improve the accuracy of the gradient calculation.
1 parent 4684ce4 commit 5ea6f55

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

exp/2025/09/24/inverse-grin/src/33-inverse-no-reg.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,9 @@ def muscle_orientation(self) -> Float[Array, "c 3 3"]:
182182
def make_params(self, q: Float[Array, "ca 6"]) -> Params:
183183
activation: Float[Array, "c 6"] = sim_jax.rest_activation(self.input.n_cells)
184184
# q = q.at[:, :3].set(jnp.exp(q[:, :3]))
185-
activation = activation.at[self.active_mask].set(q)
185+
activation = activation.at[self.active_mask, 0].set(jnp.reciprocal(q))
186+
activation = activation.at[self.active_mask, 1].set(jnp.sqrt(q))
187+
activation = activation.at[self.active_mask, 2].set(jnp.sqrt(q))
186188
# activation = sim_jax.transform_activation(activation, self.muscle_orientation)
187189
return Params(activation=activation)
188190

@@ -240,7 +242,7 @@ def fun_and_jac(self, q: Array) -> tuple[Scalar, Array]:
240242
-dLdu_free,
241243
tol=1e-5,
242244
atol=1e-15,
243-
maxiter=ic(u_free.size),
245+
maxiter=ic(10 * u_free.size),
244246
M=lambda x: P_free * x,
245247
)
246248
logger.info("linear solve > info: {}", info)
@@ -395,12 +397,10 @@ def callback(intermediate_result: optim.Solution) -> None:
395397
# result: pv.UnstructuredGrid = mesh.warp_by_vector("solution") # pyright: ignore[reportAssignmentType]
396398
writer.append(mesh)
397399

400+
q_init: Float[Array, " ca"] = jnp.ones((inverse.n_active_cells,))
398401
inverse.reg_mean_weight = 0.0
399402
inverse.reg_shear_weight = 0.0
400403
inverse.reg_volume_weight = 0.0
401-
402-
q_init: Float[Array, "ca 6"] = sim_jax.rest_activation(inverse.n_active_cells)
403-
# q_init = q_init.at[:, :3].set(jnp.log(1.0))
404404
callback(optim.Solution({"x": q_init}))
405405
optimizer = optim.MinimizerScipy(
406406
jit=False, method="L-BFGS-B", tol=1e-15, options={}

0 commit comments

Comments
 (0)