Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions docs/zh/examples/amgnet.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,46 @@
python amgnet_cylinder.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/amgnet/amgnet_cylinder_pretrained.pdparams
```

=== "模型导出命令"

=== "amgnet_airfoil"

``` sh
python amgnet_airfoil.py mode=export
```

=== "amgnet_cylinder"

``` sh
python amgnet_cylinder.py mode=export
```

=== "模型推理命令"

=== "amgnet_airfoil"

``` sh
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AMGNet/data.zip
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AMGNet/data.zip -o data.zip
# unzip it
unzip data.zip
python amgnet_airfoil.py mode=infer
```

=== "amgnet_cylinder"

``` sh
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AMGNet/data.zip
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AMGNet/data.zip -o data.zip
# unzip it
unzip data.zip
python amgnet_cylinder.py mode=infer
```

| 预训练模型 | 指标 |
|:--| :--|
| [amgnet_airfoil_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/amgnet/amgnet_airfoil_pretrained.pdparams) | loss(RMSE_validator): 0.0001 <br> RMSE.RMSE(RMSE_validator): 0.01315 |
Expand Down
79 changes: 79 additions & 0 deletions examples/amgnet/amgnet_airfoil.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,86 @@ def evaluate(cfg: DictConfig):
"airfoil",
)

def export(cfg: DictConfig):
# set model
model = ppsci.arch.AMGNet(**cfg.MODEL)

# initialize solver
solver = ppsci.solver.Solver(model,cfg=cfg,)

# export
from paddle.static import InputSpec

input_spec = [
{key: InputSpec([None,2],"float32",name=key) for key in model.input_keys},
]
solver.export(input_spec,cfg.INFER.export_path)

def inference(cfg: DictConfig):
from deploy.python_infer import pinn_predictor

predictor = pinn_predictor.PINNPredictor(cfg)
eval_dataloader_cfg = {
"dataset": {
"name": "MeshAirfoilDataset",
"input_keys": ("input",),
"label_keys": ("label",),
"data_dir": cfg.EVAL_DATA_DIR,
"mesh_graph_path": cfg.EVAL_MESH_GRAPH_PATH,
},
"batch_size": cfg.INFER.batch_size,
"sampler": {
"name": "BatchSampler",
"drop_last": False,
"shuffle": False,
},
}
dataset = ppsci.data.dataset.MeshAirfoilDataset(**eval_dataloader_cfg["dataset"])

sample = dataset[0]
input_dict = {
"input": sample["input"].pos
}

output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)

# mapping data to cfg.MODEL.output_keys
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
}

velocity_x = output_dict["pred"][:, 0:1]
velocity_y = output_dict["pred"][:, 1:2]
pressure = output_dict["pred"][:, 2:3]

utils.log_images(
sample["input"].pos,
velocity_x,
sample["label"].y[:, 0:1],
dataset.elems_list,
0,
"airfoil_inference_x_velocity"
)

utils.log_images(
sample["input"].pos,
velocity_y,
sample["label"].y[:, 1:2],
dataset.elems_list,
1,
"airfoil_inference_y_velocity"
)

utils.log_images(
sample["input"].pos,
pressure,
sample["label"].y[:, 2:3],
dataset.elems_list,
2,
"airfoil_inference_pressure"
)

@hydra.main(version_base=None, config_path="./conf", config_name="amgnet_airfoil.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
Expand Down
79 changes: 79 additions & 0 deletions examples/amgnet/amgnet_cylinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,86 @@ def evaluate(cfg: DictConfig):
"cylinder",
)

def export(cfg: DictConfig):
# set model
model = ppsci.arch.AMGNet(**cfg.MODEL)

# initialize solver
solver = ppsci.solver.Solver(model,cfg=cfg)

# export
from paddle.static import InputSpec

input_spec = [
{"input": InputSpec([None, 2], "float32", name="input")},
]
solver.export(input_spec, cfg.INFER.export_path)

def inference(cfg: DictConfig):
from deploy.python_infer import pinn_predictor

predictor = pinn_predictor.PINNPredictor(cfg)
eval_dataloader_cfg = {
"dataset": {
"name": "MeshCylinderDataset",
"input_keys": ("input",),
"label_keys": ("label",),
"data_dir": cfg.EVAL_DATA_DIR,
"mesh_graph_path": cfg.EVAL_MESH_GRAPH_PATH,
},
"batch_size": cfg.INFER.batch_size,
"sampler": {
"name": "BatchSampler",
"drop_last": False,
"shuffle": False,
},
}
dataset = ppsci.data.dataset.MeshCylinderDataset(**eval_dataloader_cfg["dataset"])

sample = dataset[0]
input_dict = {
"input": sample["input"].pos
}

output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)

# mapping data to cfg.MODEL.output_keys
output_dict = {
store_key: output_dict[infer_key]
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
}

velocity_x = output_dict["pred"][:, 0:1]
velocity_y = output_dict["pred"][:, 1:2]
pressure = output_dict["pred"][:, 2:3]

utils.log_images(
sample["input"].pos,
velocity_x,
sample["label"].y[:, 0:1],
dataset.elems_list,
0,
"cylinder_inference_x_velocity"
)

utils.log_images(
sample["input"].pos,
velocity_y,
sample["label"].y[:, 1:2],
dataset.elems_list,
1,
"cylinder_inference_y_velocity"
)

utils.log_images(
sample["input"].pos,
pressure,
sample["label"].y[:, 2:3],
dataset.elems_list,
2,
"cylinder_inference_pressure"
)

@hydra.main(version_base=None, config_path="./conf", config_name="amgnet_cylinder.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
Expand Down
18 changes: 18 additions & 0 deletions examples/amgnet/conf/amgnet_airfoil.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,21 @@ EVAL:
batch_size: 1
pretrained_model_path: null
eval_with_no_grad: true

# inference settings
INFER:
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/amgnet/amgnet_airfoil_pretrained.pdparams
export_path: ./inference/amgnet_airfoil
pdmodel_path: ${INFER.export_path}.pdmodel
pdiparams_path: ${INFER.export_path}.pdiparams
onnx_path: ${INFER.export_path}.onnx
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 2000
gpu_id: 0
max_batch_size: 1024
num_cpu_threads: 10
batch_size: 1024
18 changes: 18 additions & 0 deletions examples/amgnet/conf/amgnet_cylinder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,21 @@ EVAL:
batch_size: 1
pretrained_model_path: null
eval_with_no_grad: true

# inference settings
INFER:
pretrained_model_path: https://paddle-org.bj.bcebos.com/paddlescience/models/amgnet/amgnet_cylinder_pretrained.pdparams
export_path: ./inference/amgnet_cylinder
pdmodel_path: ${INFER.export_path}.pdmodel
pdiparams_path: ${INFER.export_path}.pdiparams
onnx_path: ${INFER.export_path}.onnx
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 2000
gpu_id: 0
max_batch_size: 1024
num_cpu_threads: 10
batch_size: 1024
91 changes: 53 additions & 38 deletions examples/quick_start/case3.ipynb

Large diffs are not rendered by default.