Skip to content

Commit 4684ce4

Browse files
committed
fix(inverse-grin): update regularization weights in loss calculation
1 parent ec62188 commit 4684ce4

File tree

4 files changed

+14
-13
lines changed

4 files changed

+14
-13
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:fddfec6f322a3cbec7a1b3a65cac5103ee2566fa440e1072ec41f9bba2016da8
3-
size 22592
2+
oid sha256:b9ba7de5157ef92dccc218fec5d5bafda378347dfb89969c406a10f69a8a74c8
3+
size 115
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:5f2d7537bc7844f6ff6e9c7c834a5c777b7367adaa52cad51c547fd2388b7068
2+
oid sha256:a9c078b86a0155a8b79d5fb1b91a2b53eb4453aaea188801b90c01ddb245d40f
33
size 4978337

exp/2025/09/24/inverse-grin/src/32-inverse.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,9 @@ def fun_and_jac(self, q: Array) -> tuple[Scalar, Array]:
275275

276276
def loss(self, x: Vector, params: Params) -> tuple[Scalar, InverseLossAux]:
277277
loss_surface: Scalar = self.loss_surface(x)
278-
reg_mean: Scalar = 1e3 * self.regularize_mean(params)
279-
reg_shear: Scalar = 1e3 * self.regularize_shear(params)
280-
reg_volume: Scalar = 1e3 * self.regularize_volume(params)
278+
reg_mean: Scalar = self.reg_mean_weight * self.regularize_mean(params)
279+
reg_shear: Scalar = self.reg_shear_weight * self.regularize_shear(params)
280+
reg_volume: Scalar = self.reg_volume_weight * self.regularize_volume(params)
281281
# jax.debug.print("loss_surface = {}", loss_surface)
282282
# jax.debug.print("reg_mean = {}", reg_mean)
283283
# jax.debug.print("reg_shear = {}", reg_shear)

exp/2025/09/24/inverse-grin/src/33-inverse-no-reg.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,9 @@ def fun_and_jac(self, q: Array) -> tuple[Scalar, Array]:
275275

276276
def loss(self, x: Vector, params: Params) -> tuple[Scalar, InverseLossAux]:
277277
loss_surface: Scalar = self.loss_surface(x)
278-
reg_mean: Scalar = 1e3 * self.regularize_mean(params)
279-
reg_shear: Scalar = 1e3 * self.regularize_shear(params)
280-
reg_volume: Scalar = 1e3 * self.regularize_volume(params)
278+
reg_mean: Scalar = self.reg_mean_weight * self.regularize_mean(params)
279+
reg_shear: Scalar = self.reg_shear_weight * self.regularize_shear(params)
280+
reg_volume: Scalar = self.reg_volume_weight * self.regularize_volume(params)
281281
# jax.debug.print("loss_surface = {}", loss_surface)
282282
# jax.debug.print("reg_mean = {}", reg_mean)
283283
# jax.debug.print("reg_shear = {}", reg_shear)
@@ -341,7 +341,7 @@ def regularize_volume(self, params: Params) -> Scalar:
341341
return regularization
342342

343343

344-
def main(cfg: Config) -> None:
344+
def main(cfg: Config) -> None: # noqa: PLR0915
345345
mesh: pv.UnstructuredGrid = melon.load_unstructured_grid(cfg.input)
346346
target: pv.UnstructuredGrid = melon.load_unstructured_grid(cfg.target)
347347

@@ -395,16 +395,17 @@ def callback(intermediate_result: optim.Solution) -> None:
395395
# result: pv.UnstructuredGrid = mesh.warp_by_vector("solution") # pyright: ignore[reportAssignmentType]
396396
writer.append(mesh)
397397

398+
inverse.reg_mean_weight = 0.0
399+
inverse.reg_shear_weight = 0.0
400+
inverse.reg_volume_weight = 0.0
401+
398402
q_init: Float[Array, "ca 6"] = sim_jax.rest_activation(inverse.n_active_cells)
399403
# q_init = q_init.at[:, :3].set(jnp.log(1.0))
400404
callback(optim.Solution({"x": q_init}))
401405
optimizer = optim.MinimizerScipy(
402406
jit=False, method="L-BFGS-B", tol=1e-15, options={}
403407
)
404408
inverse.linear_solver = lx.CG(rtol=1e-3, atol=1e-30, max_steps=10000)
405-
inverse.reg_mean_weight = 0.0
406-
inverse.reg_shear_weight = 0.0
407-
inverse.reg_volume_weight = 0.0
408409
solution: optim.Solution = optimizer.minimize(
409410
x0=q_init, fun_and_jac=inverse.fun_and_jac, callback=callback
410411
)

0 commit comments

Comments
 (0)