Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions docs/zh/examples/phygeonet.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,40 @@

```

=== "模型导出命令"

``` sh
# heat_equation
python heat_equation.py mode=export
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

heat_equation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确实是heat_equation🥺

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确实是heat_equation🥺

哦这条评论忘记删了,没事儿


# heat_equation_bc
python heat_equation_with_bc.py mode=export
```

=== "模型推理命令"

``` sh
# heat_equation
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation.npz -P ./data/

# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation.npz --create-dirs -o ./data/heat_equation.npz

python heat_equation.py mode=infer

# heat_equation_bc
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation_bc.npz -P ./data/
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation_bc_test.npz -P ./data/

# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation_bc.npz --create-dirs -o ./data/heat_equation.npz
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyGeoNet/heat_equation_bc_test.npz --create-dirs -o ./data/heat_equation.npz

python heat_equation_with_bc.py mode=infer
```

| 模型 | mRes | ev |
| :-- | :-- | :-- |
| [heat_equation_pretrain.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/PhyGeoNet/heat_equation_pretrain.pdparams) | 0.815 |0.095|
Expand Down
18 changes: 18 additions & 0 deletions examples/phygeonet/conf/heat_equation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,21 @@ TRAIN:
EVAL:
pretrained_model_path: null
eval_with_no_grad: true

# inference settings
INFER:
pretrained_model_path: 'https://paddle-org.bj.bcebos.com/paddlescience/models/PhyGeoNet/heat_equation_pretrain.pdparams'
export_path: ./inference/heat_equation
pdmodel_path: ${INFER.export_path}.pdmodel
pdiparams_path: ${INFER.export_path}.pdiparams
onnx_path: ${INFER.export_path}.onnx
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 20
gpu_id: 0
max_batch_size: 256
num_cpu_threads: 10
batch_size: 256
19 changes: 19 additions & 0 deletions examples/phygeonet/conf/heat_equation_with_bc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ hydra:

# general settings
mode: train # running mode: train/eval
log_freq: 50
seed: 66
data_dir: ./data/heat_equation_bc.npz
test_data_dir: ./data/heat_equation_bc_test.npz
Expand Down Expand Up @@ -53,3 +54,21 @@ TRAIN:
EVAL:
pretrained_model_path: null
eval_with_no_grad: true

# inference settings
INFER:
pretrained_model_path: 'https://paddle-org.bj.bcebos.com/paddlescience/models/PhyGeoNet/heat_equation_bc_pretrain.pdparams'
export_path: ./inference/heat_equation_bc
pdmodel_path: ${INFER.export_path}.pdmodel
pdiparams_path: ${INFER.export_path}.pdiparams
onnx_path: ${INFER.export_path}.onnx
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 20
gpu_id: 0
max_batch_size: 256
num_cpu_threads: 10
batch_size: 256
91 changes: 90 additions & 1 deletion examples/phygeonet/heat_equation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os.path as osp
from typing import Dict

import hydra
Expand Down Expand Up @@ -153,14 +154,102 @@ def evaluate(cfg: DictConfig):
plt.close(fig)


def export(cfg: DictConfig):
model = ppsci.arch.USCNN(**cfg.MODEL)
# initialize solver
solver = ppsci.solver.Solver(
model,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)
# export model
from paddle.static import InputSpec

input_spec = [
{
key: InputSpec([None, 2, 19, 84], "float32", name=key)
for key in model.input_keys
},
]
solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
from deploy.python_infer import pinn_predictor

predictor = pinn_predictor.PINNPredictor(cfg)
data = np.load(cfg.data_dir)
coords = data["coords"]
ofv_sb = data["OFV_sb"]

## create model
pad_singleside = cfg.MODEL.pad_singleside
input_spec = {"coords": coords}

output_v = predictor.predict(input_spec, cfg.INFER.batch_size)
# mapping data to cfg.INFER.output_keys
output_v = {
store_key: output_v[infer_key]
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_v.keys())
}

output_v = output_v["output_v"]

output_v[0, 0, -pad_singleside:, pad_singleside:-pad_singleside] = 0
output_v[0, 0, :pad_singleside, pad_singleside:-pad_singleside] = 1
output_v[0, 0, pad_singleside:-pad_singleside, -pad_singleside:] = 1
output_v[0, 0, pad_singleside:-pad_singleside, 0:pad_singleside] = 1
output_v[0, 0, 0, 0] = 0.5 * (output_v[0, 0, 0, 1] + output_v[0, 0, 1, 0])
output_v[0, 0, 0, -1] = 0.5 * (output_v[0, 0, 0, -2] + output_v[0, 0, 1, -1])

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里麻烦补充上以下代码,和evaluate保持一致:

    ev = paddle.sqrt(
        paddle.mean((ofv_sb - output_v[0, 0]) ** 2) / paddle.mean(ofv_sb**2)
    ).item()
    logger.info(f"ev: {ev}")

ev = paddle.sqrt(
paddle.mean((ofv_sb - output_v[0, 0]) ** 2) / paddle.mean(ofv_sb**2)
).item()
logger.info(f"ev: {ev}")

fig = plt.figure()
ax = plt.subplot(1, 2, 1)
utils.visualize(
ax,
coords[0, 0, 1:-1, 1:-1],
coords[0, 1, 1:-1, 1:-1],
output_v[0, 0, 1:-1, 1:-1],
"horizontal",
[0, 1],
)
utils.set_axis_label(ax, "p")
ax.set_title("CNN " + r"$T$")
ax.set_aspect("equal")
ax = plt.subplot(1, 2, 2)
utils.visualize(
ax,
coords[0, 0, 1:-1, 1:-1],
coords[0, 1, 1:-1, 1:-1],
ofv_sb[1:-1, 1:-1],
"horizontal",
[0, 1],
)
utils.set_axis_label(ax, "p")
ax.set_aspect("equal")
ax.set_title("FV " + r"$T$")
fig.tight_layout(pad=1)
fig.savefig(osp.join(cfg.output_dir, "result.png"), bbox_inches="tight")
plt.close(fig)


@hydra.main(version_base=None, config_path="./conf", config_name="heat_equation.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
Expand Down
121 changes: 120 additions & 1 deletion examples/phygeonet/heat_equation_with_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,119 @@ def evaluate(cfg: DictConfig):
plt.close(fig1)


def export(cfg: DictConfig):
model = ppsci.arch.USCNN(**cfg.MODEL)
# initialize solver
solver = ppsci.solver.Solver(
model,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)
# export model
from paddle.static import InputSpec

input_spec = [
{
key: InputSpec([None, 1, 19, 84], "float32", name=key)
for key in model.input_keys
},
]
solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
from deploy.python_infer import pinn_predictor

predictor = pinn_predictor.PINNPredictor(cfg)
pad_singleside = cfg.MODEL.pad_singleside

data = np.load(cfg.test_data_dir)
paras = data["paras"]
truths = data["truths"]
coords = data["coords"]

paras = paras.reshape([paras.shape[0], 1, paras.shape[1], paras.shape[2]])
input_spec = {"coords": paras}
output_v = predictor.predict(input_spec, cfg.INFER.batch_size)
# mapping data to cfg.INFER.output_keys
output_v = {
store_key: output_v[infer_key]
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_v.keys())
}
output_v = output_v["output_v"]
num_sample = output_v.shape[0]
for j in range(num_sample):
# Impose BC
output_v[j, 0, -pad_singleside:, pad_singleside:-pad_singleside] = output_v[
j, 0, 1:2, pad_singleside:-pad_singleside
]
output_v[j, 0, :pad_singleside, pad_singleside:-pad_singleside] = output_v[
j, 0, -2:-1, pad_singleside:-pad_singleside
]
output_v[j, 0, :, -pad_singleside:] = 0
output_v[j, 0, :, 0:pad_singleside] = paras[j, 0, 0, 0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同理,这里补充evaluate里的指标评估代码:

error = paddle.sqrt(
    paddle.mean((truths - output_v) ** 2) / paddle.mean(truths**2)
).item()
logger.info(f"The average error: {error / num_sample}")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

同理,这里补充evaluate里的指标评估代码:

error = paddle.sqrt(
    paddle.mean((truths - output_v) ** 2) / paddle.mean(truths**2)
).item()
logger.info(f"The average error: {error / num_sample}")


error = paddle.sqrt(
paddle.mean((truths - output_v) ** 2) / paddle.mean(truths**2)
).item()
logger.info(f"The average error: {error / num_sample}")

output_vs = output_v
PARALIST = [1, 2, 3, 4, 5, 6, 7]
for i in range(len(PARALIST)):
truth = truths[i]
coord = coords[i]
output_v = output_vs[i]
truth = truth.reshape(1, 1, truth.shape[0], truth.shape[1])
coord = coord.reshape(1, 2, coord.shape[2], coord.shape[3])
fig1 = plt.figure()
xylabelsize = 20
xytickssize = 20
titlesize = 20
ax = plt.subplot(1, 2, 1)
_, cbar = utils.visualize(
ax,
coord[0, 0, :, :],
coord[0, 1, :, :],
output_v[0, :, :],
"horizontal",
[0, max(PARALIST)],
)
ax.set_aspect("equal")
utils.set_axis_label(ax, "p")
ax.set_title("PhyGeoNet " + r"$T$", fontsize=titlesize)
ax.set_xlabel(xlabel=r"$x$", fontsize=xylabelsize)
ax.set_ylabel(ylabel=r"$y$", fontsize=xylabelsize)
ax.set_xticks([-1, 0, 1])
ax.set_yticks([-1, 0, 1])
ax.tick_params(axis="x", labelsize=xytickssize)
ax.tick_params(axis="y", labelsize=xytickssize)
cbar.set_ticks([0, 1, 2, 3, 4, 5, 6, 7])
cbar.ax.tick_params(labelsize=xytickssize)
ax = plt.subplot(1, 2, 2)
_, cbar = utils.visualize(
ax,
coord[0, 0, :, :],
coord[0, 1, :, :],
truth[0, 0, :, :],
"horizontal",
[0, max(PARALIST)],
)
ax.set_aspect("equal")
utils.set_axis_label(ax, "p")
ax.set_title("FV " + r"$T$", fontsize=titlesize)
ax.set_xlabel(xlabel=r"$x$", fontsize=xylabelsize)
ax.set_ylabel(ylabel=r"$y$", fontsize=xylabelsize)
ax.set_xticks([-1, 0, 1])
ax.set_yticks([-1, 0, 1])
ax.tick_params(axis="x", labelsize=xytickssize)
ax.tick_params(axis="y", labelsize=xytickssize)
cbar.set_ticks([0, 1, 2, 3, 4, 5, 6, 7])
cbar.ax.tick_params(labelsize=xytickssize)
fig1.tight_layout(pad=1)
fig1.savefig(osp.join(cfg.output_dir, f"Para{i}T.png"), bbox_inches="tight")
plt.close(fig1)


@hydra.main(
version_base=None, config_path="./conf", config_name="heat_equation_with_bc.yaml"
)
Expand All @@ -196,8 +309,14 @@ def main(cfg: DictConfig):
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
Expand Down