Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/zh/examples/labelfree_DNN_surrogate.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@
python aneurysm_flow.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/LabelFree-DNN-Surrogate/aneurysm_flow.pdparams
```


=== "模型导出命令"

``` sh
python poiseuille_flow.py mode=export
```


=== "模型推理命令"

``` sh
python poiseuille_flow.py mode=infer
```


| 预训练模型 | 指标 |
|:--| :--|
|[aneurysm_flow.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/LabelFree-DNN-Surrogate/aneurysm_flow.pdparams)| L-2 error u : 2.548e-4 <br> L-2 error v : 7.169e-5 |
Expand Down
18 changes: 18 additions & 0 deletions examples/pipe/conf/poiseuille_flow.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,21 @@ TRAIN:
EVAL:
pretrained_model_path: null
eval_with_no_grad: true

# export settings
EXPORT:
pretrained_model_path: "https://paddle-org.bj.bcebos.com/paddlescience/models/LabelFree-DNN-Surrogate/poiseuille_flow_pretrained.pdparams"

# inference settings
INFER:
pretrained_model_path: "https://paddle-org.bj.bcebos.com/paddlescience/models/LabelFree-DNN-Surrogate/poiseuille_flow_pretrained.pdparams"
device: gpu
engine: native
precision: fp32
ir_optim: true
min_subgraph_size: 5
gpu_mem: 2000
gpu_id: 0
max_batch_size: 8192
num_cpu_threads: 10
batch_size: 8192
249 changes: 248 additions & 1 deletion examples/pipe/poiseuille_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,14 +429,261 @@ def forward(self, output_dict, label_dict):
plt.savefig(osp.join(PLOT_DIR, "pipe_unformUQ.png"), bbox_inches="tight")


def export(cfg):
from paddle.static import InputSpec

model_u = ppsci.arch.MLP(**cfg.MODEL.u_net)
model_v = ppsci.arch.MLP(**cfg.MODEL.v_net)
model_p = ppsci.arch.MLP(**cfg.MODEL.p_net)
Comment on lines +435 to +437
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. 看起来缺少这部分代码代码,image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done. 之前测试可以导出就没加这部分代码,已经修改好了。


solver_u = ppsci.solver.Solver(
model_u, pretrained_model_path=cfg.EXPORT.pretrained_model_path
)
solver_v = ppsci.solver.Solver(
model_v, pretrained_model_path=cfg.EXPORT.pretrained_model_path
)
solver_p = ppsci.solver.Solver(
model_p, pretrained_model_path=cfg.EXPORT.pretrained_model_path
)

input_spec_u = [
{
key: InputSpec([None, 1], "float32", name=key)
for key in cfg.MODEL.u_net.input_keys
},
]
input_spec_v = [
{
key: InputSpec([None, 1], "float32", name=key)
for key in cfg.MODEL.v_net.input_keys
},
]
input_spec_p = [
{
key: InputSpec([None, 1], "float32", name=key)
for key in cfg.MODEL.p_net.input_keys
},
]

export_path_u = os.path.join(cfg.output_dir, "u_net")
export_path_v = os.path.join(cfg.output_dir, "v_net")
export_path_p = os.path.join(cfg.output_dir, "p_net")

solver_u.export(input_spec_u, export_path_u)
solver_v.export(input_spec_v, export_path_v)
solver_p.export(input_spec_p, export_path_p)

print(f"Inference models have been exported to {cfg.output_dir}.")


def inference(cfg: DictConfig):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inference应该使用Predictor,加载export导出的模型进行计算,而不是使用solver

NU_MEAN = 0.001
NU_STD = 0.9
L = 1.0 # length of pipe
R = 0.05 # radius of pipe
RHO = 1 # density
P_OUT = 0 # pressure at the outlet of pipe
P_IN = 0.1 # pressure at the inlet of pipe
N_x = 10
N_y = 50
N_p = 50
X_IN = 0
X_OUT = X_IN + L
Y_START = -R
Y_END = Y_START + 2 * R
NU_START = NU_MEAN - NU_MEAN * NU_STD # 0.0001
NU_END = NU_MEAN + NU_MEAN * NU_STD # 0.1

data_1d_x = np.linspace(
X_IN, X_OUT, N_x, endpoint=True, dtype=paddle.get_default_dtype()
)
data_1d_y = np.linspace(
Y_START, Y_END, N_y, endpoint=True, dtype=paddle.get_default_dtype()
)
data_1d_nu = np.linspace(
NU_START, NU_END, N_p, endpoint=True, dtype=paddle.get_default_dtype()
)
data_2d_xy = (
np.array(np.meshgrid(data_1d_x, data_1d_y, data_1d_nu)).reshape(3, -1).T
)

model_u = ppsci.arch.MLP(("sin(x)", "cos(x)", "y", "nu"), ("u",), 3, 50, "swish")
model_v = ppsci.arch.MLP(("sin(x)", "cos(x)", "y", "nu"), ("v",), 3, 50, "swish")
model_p = ppsci.arch.MLP(("sin(x)", "cos(x)", "y", "nu"), ("p",), 3, 50, "swish")

class Transform:
def input_trans(self, input):
self.input = input
x, y = input["x"], input["y"]
nu = input["nu"]
b = 2 * np.pi / (X_OUT - X_IN)
c = np.pi * (X_IN + X_OUT) / (X_IN - X_OUT)
sin_x = X_IN * paddle.sin(b * x + c)
cos_x = X_IN * paddle.cos(b * x + c)
return {"sin(x)": sin_x, "cos(x)": cos_x, "y": y, "nu": nu}

def output_trans_u(self, input, out):
return {"u": out["u"] * (R**2 - self.input["y"] ** 2)}

def output_trans_v(self, input, out):
return {"v": (R**2 - self.input["y"] ** 2) * out["v"]}

def output_trans_p(self, input, out):
return {
"p": (
(P_IN - P_OUT) * (X_OUT - self.input["x"]) / L
+ (X_IN - self.input["x"]) * (X_OUT - self.input["x"]) * out["p"]
)
}

transform = Transform()
model_u.register_input_transform(transform.input_trans)
model_v.register_input_transform(transform.input_trans)
model_p.register_input_transform(transform.input_trans)
model_u.register_output_transform(transform.output_trans_u)
model_v.register_output_transform(transform.output_trans_v)
model_p.register_output_transform(transform.output_trans_p)
model = ppsci.arch.ModelList((model_u, model_v, model_p))

input_dict = {
"x": data_2d_xy[:, 0:1],
"y": data_2d_xy[:, 1:2],
"nu": data_2d_xy[:, 2:3],
}
u_analytical = np.zeros([N_y, N_x, N_p])
dP = P_IN - P_OUT

for i in range(N_p):
uy = (R**2 - data_1d_y**2) * dP / (2 * L * data_1d_nu[i] * RHO)
u_analytical[:, :, i] = np.tile(uy.reshape([N_y, 1]), N_x)

label_dict = {"u": np.ones_like(input_dict["x"])}
weight_dict = {"u": np.ones_like(input_dict["x"])}

dataset_vel = {
"name": "NamedArrayDataset",
"input": input_dict,
"label": label_dict,
"weight": weight_dict,
}
eval_cfg = {
"sampler": {
"name": "BatchSampler",
"shuffle": False,
"drop_last": False,
},
"batch_size": 2000,
}
eval_cfg["dataset"] = dataset_vel

solver = ppsci.solver.Solver(
model,
output_dir=cfg.output_dir,
pretrained_model_path=cfg.INFER.pretrained_model_path,
)

# inference
output_dict = solver.predict(input_dict, return_numpy=True)
u_pred = output_dict["u"].reshape(N_y, N_x, N_p)

PLOT_DIR = osp.join(cfg.output_dir, "visu")
os.makedirs(PLOT_DIR, exist_ok=True)

fontsize = 16
idx_X = int(round(N_x / 2)) # pipe velocity section at L/2
nu_index = [3, 6, 9, 12, 14, 20, 49] # pick 7 nu samples
ytext = [0.55, 0.5, 0.4, 0.28, 0.1, 0.05, 0.001] # text y position

plt.figure(1)
plt.clf()
for idxP in range(len(nu_index)):
ax1 = plt.subplot(111)
plt.plot(
data_1d_y,
u_analytical[:, idx_X, nu_index[idxP]],
color="darkblue",
linestyle="-",
lw=3.0,
alpha=1.0,
)
plt.plot(
data_1d_y,
u_pred[:, idx_X, nu_index[idxP]],
color="red",
linestyle="--",
dashes=(5, 5),
lw=2.0,
alpha=1.0,
)
plt.text(
-0.012,
ytext[idxP],
rf"$\nu = $ {data_1d_nu[nu_index[idxP]]:.2g}",
{"color": "k", "fontsize": fontsize - 4},
)

plt.ylabel(r"$u(y)$", fontsize=fontsize)
plt.xlabel(r"$y$", fontsize=fontsize)
ax1.tick_params(axis="x", labelsize=fontsize)
ax1.tick_params(axis="y", labelsize=fontsize)
ax1.set_xlim([-0.05, 0.05])
ax1.set_ylim([0.0, 0.62])
plt.savefig(osp.join(PLOT_DIR, "pipe_uProfiles.png"), bbox_inches="tight")

# Distribution of center velocity
# Predicted result
input_dict_test = {
"x": data_2d_xy[:, 0:1],
"y": data_2d_xy[:, 1:2],
"nu": data_2d_xy[:, 2:3],
}
output_dict_test = solver.predict(input_dict_test, return_numpy=True)
u_max_pred = output_dict_test["u"]

# Analytical result, y = 0
u_max_a = (R**2) * (P_IN - P_OUT) / (2 * L * data_1d_nu * RHO)

plt.figure(2)
plt.clf()
ax1 = plt.subplot(111)
sns.kdeplot(
u_max_a,
fill=True,
color="black",
label="Analytical",
linestyle="-",
linewidth=3,
)
sns.kdeplot(
u_max_pred,
fill=False,
color="red",
label="DNN",
linestyle="--",
linewidth=3.5,
)
plt.legend(prop={"size": fontsize})
plt.xlabel(r"$u_c$", fontsize=fontsize)
plt.ylabel(r"PDF", fontsize=fontsize)
ax1.tick_params(axis="x", labelsize=fontsize)
ax1.tick_params(axis="y", labelsize=fontsize)
plt.savefig(osp.join(PLOT_DIR, "pipe_unformUQ.png"), bbox_inches="tight")


@hydra.main(version_base=None, config_path="./conf", config_name="poiseuille_flow.yaml")
def main(cfg: DictConfig):
if cfg.mode == "train":
train(cfg)
elif cfg.mode == "eval":
evaluate(cfg)
elif cfg.mode == "export":
export(cfg)
elif cfg.mode == "infer":
inference(cfg)
else:
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
)


if __name__ == "__main__":
Expand Down