Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/zh/examples/adv_cvit.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ examples/adv/adv_cvit.py:117:125
--8<--
```

### 3.7 模型训练、评估
### 3.6 模型训练、评估

完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`,然后启动训练、评估。

Expand Down
20 changes: 10 additions & 10 deletions docs/zh/examples/phylstm.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,19 @@ examples/phylstm/phylstm2.py:37:100

设置训练数据集和损失计算函数,返回字段,代码如下所示:

``` py linenums="119"
``` py linenums="120"
--8<--
examples/phylstm/phylstm2.py:119:145
examples/phylstm/phylstm2.py:120:146
--8<--
```

### 3.4 评估器构建

设置评估数据集和损失计算函数,返回字段,代码如下所示:

``` py linenums="147"
``` py linenums="148"
--8<--
examples/phylstm/phylstm2.py:147:174
examples/phylstm/phylstm2.py:148:170
--8<--
```

Expand All @@ -136,27 +136,27 @@ examples/phylstm/conf/phylstm2.yaml:39:39

训练过程会调用优化器来更新模型参数,此处选择 `Adam` 优化器并设定 `learning_rate` 为 1e-3。

``` py linenums="177"
``` py linenums="172"
--8<--
examples/phylstm/phylstm2.py:177:177
examples/phylstm/phylstm2.py:172:173
--8<--
```

### 3.7 模型训练与评估

完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`。

``` py linenums="178"
``` py linenums="174"
--8<--
examples/phylstm/phylstm2.py:178:192
examples/phylstm/phylstm2.py:174:180
--8<--
```

最后启动训练、评估即可:

``` py linenums="194"
``` py linenums="182"
--8<--
examples/phylstm/phylstm2.py:194:197
examples/phylstm/phylstm2.py:182:185
--8<--
```

Expand Down
15 changes: 15 additions & 0 deletions examples/phylstm/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,4 +181,19 @@ def get(self, epochs=1):
np.asarray(0.0, dtype=paddle.get_default_dtype())
)

def to_numpy_dict(dct):
return {k: np.asarray(v, dtype="float32") for k, v in dct.items()}

input_dict_train = to_numpy_dict(input_dict_train)
for k, v in input_dict_train.items():
print(f"input_dict_train {k} {type(v)}")
label_dict_train = to_numpy_dict(label_dict_train)
for k, v in label_dict_train.items():
print(f"label_dict_train {k} {type(v)}")
input_dict_val = to_numpy_dict(input_dict_val)
for k, v in input_dict_val.items():
print(f"input_dict_val {k} {type(v)}")
label_dict_val = to_numpy_dict(label_dict_val)
for k, v in label_dict_val.items():
print(f"label_dict_val {k} {type(v)}")
return input_dict_train, label_dict_train, input_dict_val, label_dict_val
29 changes: 5 additions & 24 deletions examples/phylstm/phylstm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def train(cfg: DictConfig):
model.register_output_transform(functions.transform_out)

dataset_obj = functions.Dataset(eta, eta_t, g, ag, ag_c, lift, phi_t)

(
input_dict_train,
label_dict_train,
Expand Down Expand Up @@ -151,11 +152,6 @@ def train(cfg: DictConfig):
"input": input_dict_val,
"label": label_dict_val,
},
"sampler": {
"name": "BatchSampler",
"drop_last": False,
"shuffle": False,
},
"batch_size": 1,
"num_workers": 0,
},
Expand All @@ -178,17 +174,9 @@ def train(cfg: DictConfig):
solver = ppsci.solver.Solver(
model,
constraint_pde,
cfg.output_dir,
optimizer,
None,
cfg.TRAIN.epochs,
cfg.TRAIN.iters_per_epoch,
save_freq=cfg.TRAIN.save_freq,
log_freq=cfg.log_freq,
seed=cfg.seed,
optimizer=optimizer,
validator=validator_pde,
checkpoint_path=cfg.TRAIN.checkpoint_path,
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
cfg=cfg,
)

# train model
Expand Down Expand Up @@ -278,6 +266,7 @@ def evaluate(cfg: DictConfig):
model.register_output_transform(functions.transform_out)

dataset_obj = functions.Dataset(eta, eta_t, g, ag, ag_c, lift, phi_t)

(
_,
_,
Expand All @@ -292,11 +281,6 @@ def evaluate(cfg: DictConfig):
"input": input_dict_val,
"label": label_dict_val,
},
"sampler": {
"name": "BatchSampler",
"drop_last": False,
"shuffle": False,
},
"batch_size": 1,
"num_workers": 0,
},
Expand All @@ -317,11 +301,8 @@ def evaluate(cfg: DictConfig):
# initialize solver
solver = ppsci.solver.Solver(
model,
output_dir=cfg.output_dir,
seed=cfg.seed,
validator=validator_pde,
pretrained_model_path=cfg.EVAL.pretrained_model_path,
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
cfg=cfg,
)
# evaluate
solver.eval()
Expand Down