Skip to content

Commit e1adb85

Browse files
add DeepONet export and infer
1 parent f17d1b3 commit e1adb85

File tree

3 files changed

+86
-60
lines changed

3 files changed

+86
-60
lines changed

docs/zh/examples/deeponet.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,18 @@
2626
python deeponet.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/deeponet/deeponet_pretrained.pdparams
2727
```
2828

29+
=== "模型导出命令"
30+
31+
``` sh
32+
python deeponet.py mode=export
33+
```
34+
35+
=== "模型推理命令"
36+
37+
``` sh
38+
python deeponet.py mode=infer
39+
```
40+
2941
| 预训练模型 | 指标 |
3042
|:--| :--|
3143
| [deeponet_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/deeponet/deeponet_pretrained.pdparams) | loss(G_eval): 0.00003<br>L2Rel.G(G_eval): 0.01799 |

examples/operator_learning/conf/deeponet.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,21 @@ TRAIN:
6060
EVAL:
6161
pretrained_model_path: null
6262
eval_with_no_grad: true
63+
64+
# inference settings
65+
INFER:
66+
pretrained_model_path: "https://paddle-org.bj.bcebos.com/paddlescience/models/deeponet/deeponet_pretrained.pdparams"
67+
export_path: ./inference/deeponet
68+
pdmodel_path: ${INFER.export_path}.pdmodel
69+
pdiparams_path: ${INFER.export_path}.pdiparams
70+
device: gpu
71+
engine: native
72+
precision: fp32
73+
onnx_path: ${INFER.export_path}.onnx
74+
ir_optim: true
75+
min_subgraph_size: 10
76+
gpu_mem: 4000
77+
gpu_id: 0
78+
max_batch_size: 128
79+
num_cpu_threads: 4
80+
batch_size: 128

examples/operator_learning/deeponet.py

Lines changed: 56 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -89,66 +89,10 @@ def train(cfg: DictConfig):
8989
# evaluate after finished training
9090
solver.eval()
9191

92-
# visualize prediction for different functions u and corresponding G(u)
93-
dtype = paddle.get_default_dtype()
94-
95-
def generate_y_u_G_ref(
96-
u_func: Callable, G_u_func: Callable
97-
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
98-
"""Generate discretized data of given function u and corresponding G(u).
99-
100-
Args:
101-
u_func (Callable): Function u.
102-
G_u_func (Callable): Function G(u).
103-
104-
Returns:
105-
Tuple[np.ndarray, np.ndarray, np.ndarray]: Discretized data of u, y and G(u).
106-
"""
107-
x = np.linspace(0, 1, cfg.MODEL.num_loc, dtype=dtype).reshape(
108-
[1, cfg.MODEL.num_loc]
109-
)
110-
u = u_func(x)
111-
u = np.tile(u, [cfg.NUM_Y, 1])
112-
113-
y = np.linspace(0, 1, cfg.NUM_Y, dtype=dtype).reshape([cfg.NUM_Y, 1])
114-
G_ref = G_u_func(y)
115-
return u, y, G_ref
116-
117-
func_u_G_pair = [
118-
# (title_string, func_u, func_G(u)), s.t. dG/dx == u and G(u)(0) = 0
119-
(r"$u=\cos(x), G(u)=sin(x$)", lambda x: np.cos(x), lambda y: np.sin(y)), # 1
120-
(
121-
r"$u=sec^2(x), G(u)=tan(x$)",
122-
lambda x: (1 / np.cos(x)) ** 2,
123-
lambda y: np.tan(y),
124-
), # 2
125-
(
126-
r"$u=sec(x)tan(x), G(u)=sec(x) - 1$",
127-
lambda x: (1 / np.cos(x) * np.tan(x)),
128-
lambda y: 1 / np.cos(y) - 1,
129-
), # 3
130-
(
131-
r"$u=1.5^x\ln{1.5}, G(u)=1.5^x-1$",
132-
lambda x: 1.5**x * np.log(1.5),
133-
lambda y: 1.5**y - 1,
134-
), # 4
135-
(r"$u=3x^2, G(u)=x^3$", lambda x: 3 * x**2, lambda y: y**3), # 5
136-
(r"$u=4x^3, G(u)=x^4$", lambda x: 4 * x**3, lambda y: y**4), # 6
137-
(r"$u=5x^4, G(u)=x^5$", lambda x: 5 * x**4, lambda y: y**5), # 7
138-
(r"$u=6x^5, G(u)=x^6$", lambda x: 5 * x**4, lambda y: y**5), # 8
139-
(r"$u=e^x, G(u)=e^x-1$", lambda x: np.exp(x), lambda y: np.exp(y) - 1), # 9
140-
]
92+
def predict_func(input_dict):
93+
return solver.predict(input_dict, return_numpy=True)[cfg.MODEL.G_key]
14194

142-
os.makedirs(os.path.join(cfg.output_dir, "visual"), exist_ok=True)
143-
for i, (title, u_func, G_func) in enumerate(func_u_G_pair):
144-
u, y, G_ref = generate_y_u_G_ref(u_func, G_func)
145-
G_pred = solver.predict({"u": u, "y": y}, return_numpy=True)["G"]
146-
plt.plot(y, G_pred, label=r"$G(u)(y)_{ref}$")
147-
plt.plot(y, G_ref, label=r"$G(u)(y)_{pred}$")
148-
plt.legend()
149-
plt.title(title)
150-
plt.savefig(os.path.join(cfg.output_dir, "visual", f"func_{i}_result.png"))
151-
plt.clf()
95+
plot(cfg, predict_func)
15296

15397

15498
def evaluate(cfg: DictConfig):
@@ -189,6 +133,50 @@ def evaluate(cfg: DictConfig):
189133
)
190134
solver.eval()
191135

136+
def predict_func(input_dict):
137+
return solver.predict(input_dict, return_numpy=True)[cfg.MODEL.G_key]
138+
139+
plot(cfg, predict_func)
140+
141+
142+
def export(cfg: DictConfig):
143+
# set model
144+
model = ppsci.arch.DeepONet(**cfg.MODEL)
145+
146+
# initialize solver
147+
solver = ppsci.solver.Solver(
148+
model,
149+
pretrained_model_path=cfg.INFER.pretrained_model_path,
150+
)
151+
152+
# export model
153+
from paddle.static import InputSpec
154+
155+
input_spec = [
156+
{
157+
model.input_keys[0]: InputSpec(
158+
[None, 1000], "float32", name=model.input_keys[0]
159+
),
160+
model.input_keys[1]: InputSpec(
161+
[None, 1], "float32", name=model.input_keys[1]
162+
),
163+
}
164+
]
165+
solver.export(input_spec, cfg.INFER.export_path)
166+
167+
168+
def inference(cfg: DictConfig):
169+
from deploy import python_infer
170+
171+
predictor = python_infer.GeneralPredictor(cfg)
172+
173+
def predict_func(input_dict):
174+
return next(iter(predictor.predict(input_dict).values()))
175+
176+
plot(cfg, predict_func)
177+
178+
179+
def plot(cfg: DictConfig, predict_func: Callable):
192180
# visualize prediction for different functions u and corresponding G(u)
193181
dtype = paddle.get_default_dtype()
194182

@@ -242,13 +230,17 @@ def generate_y_u_G_ref(
242230
os.makedirs(os.path.join(cfg.output_dir, "visual"), exist_ok=True)
243231
for i, (title, u_func, G_func) in enumerate(func_u_G_pair):
244232
u, y, G_ref = generate_y_u_G_ref(u_func, G_func)
245-
G_pred = solver.predict({"u": u, "y": y}, return_numpy=True)["G"]
233+
G_pred = predict_func({"u": u, "y": y})
246234
plt.plot(y, G_pred, label=r"$G(u)(y)_{ref}$")
247235
plt.plot(y, G_ref, label=r"$G(u)(y)_{pred}$")
248236
plt.legend()
249237
plt.title(title)
250238
plt.savefig(os.path.join(cfg.output_dir, "visual", f"func_{i}_result.png"))
239+
logger.message(
240+
f"Saved result of function {i} to {cfg.output_dir}/visual/func_{i}_result.png"
241+
)
251242
plt.clf()
243+
plt.close()
252244

253245

254246
@hydra.main(version_base=None, config_path="./conf", config_name="deeponet.yaml")
@@ -257,6 +249,10 @@ def main(cfg: DictConfig):
257249
train(cfg)
258250
elif cfg.mode == "eval":
259251
evaluate(cfg)
252+
elif cfg.mode == "export":
253+
export(cfg)
254+
elif cfg.mode == "infer":
255+
inference(cfg)
260256
else:
261257
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
262258

0 commit comments

Comments
 (0)