Skip to content

Commit

Permalink
[Fix]Fix arch.freeze/unfreeze and change eval of control_arm example (P…
Browse files Browse the repository at this point in the history
  • Loading branch information
lijialin03 authored Nov 30, 2023
1 parent b1cc770 commit 7772ef9
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/zh/examples/control_arm.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

| 预训练模型 | 指标 |
|:--| :--|
| [inverse_x_axis_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/control_arm/inverse_x_axis_pretrained.pdparams) | loss(geo_eval): 0.02505<br>MAE.lambda_(geo_eval): 0.01580<br>MAE.mu(geo_eval): 0.01984 |
| [inverse_x_axis_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/control_arm/inverse_x_axis_pretrained.pdparams) | loss(geo_eval): 0.02505<br>L2Rel.lambda_(geo_eval): 0.06025<br>L2Rel.mu(geo_eval): 0.07949 |

## 1. 背景简介

Expand Down
4 changes: 2 additions & 2 deletions examples/control_arm/inverse_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def train(cfg: DictConfig):
"batch_size": cfg.EVAL.batch_size.validator,
},
ppsci.loss.MSELoss("sum"),
metric={"MAE": ppsci.metric.MAE()},
metric={"L2Rel": ppsci.metric.L2Rel()},
name="geo_eval",
)
validator = {geom_validator.name: geom_validator}
Expand Down Expand Up @@ -216,7 +216,7 @@ def evaluate(cfg: DictConfig):
"batch_size": cfg.EVAL.batch_size.validator,
},
ppsci.loss.MSELoss("sum"),
metric={"MAE": ppsci.metric.MAE()},
metric={"L2Rel": ppsci.metric.L2Rel()},
name="geo_eval",
)
validator = {geom_validator.name: geom_validator}
Expand Down
4 changes: 2 additions & 2 deletions ppsci/arch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,14 @@ def register_output_transform(

def freeze(self):
"""Freeze all parameters."""
for param in self.named_parameters():
for param in self.parameters():
param.stop_gradient = True

self.eval()

def unfreeze(self):
"""Unfreeze all parameters."""
for param in self.named_parameters():
for param in self.parameters():
param.stop_gradient = False

self.train()
Expand Down

0 comments on commit 7772ef9

Please sign in to comment.