@@ -275,9 +275,9 @@ def fun_and_jac(self, q: Array) -> tuple[Scalar, Array]:
275275
276276 def loss (self , x : Vector , params : Params ) -> tuple [Scalar , InverseLossAux ]:
277277 loss_surface : Scalar = self .loss_surface (x )
278- reg_mean : Scalar = 1e3 * self .regularize_mean (params )
279- reg_shear : Scalar = 1e3 * self .regularize_shear (params )
280- reg_volume : Scalar = 1e3 * self .regularize_volume (params )
278+ reg_mean : Scalar = self . reg_mean_weight * self .regularize_mean (params )
279+ reg_shear : Scalar = self . reg_shear_weight * self .regularize_shear (params )
280+ reg_volume : Scalar = self . reg_volume_weight * self .regularize_volume (params )
281281 # jax.debug.print("loss_surface = {}", loss_surface)
282282 # jax.debug.print("reg_mean = {}", reg_mean)
283283 # jax.debug.print("reg_shear = {}", reg_shear)
@@ -341,7 +341,7 @@ def regularize_volume(self, params: Params) -> Scalar:
341341 return regularization
342342
343343
344- def main (cfg : Config ) -> None :
344+ def main (cfg : Config ) -> None : # noqa: PLR0915
345345 mesh : pv .UnstructuredGrid = melon .load_unstructured_grid (cfg .input )
346346 target : pv .UnstructuredGrid = melon .load_unstructured_grid (cfg .target )
347347
@@ -395,16 +395,17 @@ def callback(intermediate_result: optim.Solution) -> None:
395395 # result: pv.UnstructuredGrid = mesh.warp_by_vector("solution") # pyright: ignore[reportAssignmentType]
396396 writer .append (mesh )
397397
398+ inverse .reg_mean_weight = 0.0
399+ inverse .reg_shear_weight = 0.0
400+ inverse .reg_volume_weight = 0.0
401+
398402 q_init : Float [Array , "ca 6" ] = sim_jax .rest_activation (inverse .n_active_cells )
399403 # q_init = q_init.at[:, :3].set(jnp.log(1.0))
400404 callback (optim .Solution ({"x" : q_init }))
401405 optimizer = optim .MinimizerScipy (
402406 jit = False , method = "L-BFGS-B" , tol = 1e-15 , options = {}
403407 )
404408 inverse .linear_solver = lx .CG (rtol = 1e-3 , atol = 1e-30 , max_steps = 10000 )
405- inverse .reg_mean_weight = 0.0
406- inverse .reg_shear_weight = 0.0
407- inverse .reg_volume_weight = 0.0
408409 solution : optim .Solution = optimizer .minimize (
409410 x0 = q_init , fun_and_jac = inverse .fun_and_jac , callback = callback
410411 )
0 commit comments