@@ -265,6 +265,86 @@ def evaluate(cfg: DictConfig):
265
265
solver .visualize ()
266
266
267
267
268
+ def export (cfg : DictConfig ):
269
+ from paddle .static import InputSpec
270
+
271
+ # set model
272
+ disp_net = ppsci .arch .MLP (** cfg .MODEL .disp_net )
273
+ stress_net = ppsci .arch .MLP (** cfg .MODEL .stress_net )
274
+ inverse_lambda_net = ppsci .arch .MLP (** cfg .MODEL .inverse_lambda_net )
275
+ inverse_mu_net = ppsci .arch .MLP (** cfg .MODEL .inverse_mu_net )
276
+ # wrap to a model_list
277
+ model = ppsci .arch .ModelList (
278
+ (disp_net , stress_net , inverse_lambda_net , inverse_mu_net )
279
+ )
280
+
281
+ # load pretrained model
282
+ solver = ppsci .solver .Solver (
283
+ model = model , pretrained_model_path = cfg .INFER .pretrained_model_path
284
+ )
285
+
286
+ # export models
287
+ input_spec = [
288
+ {
289
+ key : InputSpec ([None , 1 ], "float32" , name = key )
290
+ for key in cfg .MODEL .disp_net .input_keys
291
+ },
292
+ ]
293
+ solver .export (input_spec , cfg .INFER .export_path )
294
+
295
+
296
+ def inference (cfg : DictConfig ):
297
+ from deploy .python_infer import pinn_predictor
298
+ from ppsci .visualize import vtu
299
+
300
+ # set model predictor
301
+ predictor = pinn_predictor .PINNPredictor (cfg )
302
+
303
+ # set geometry
304
+ control_arm = ppsci .geometry .Mesh (cfg .GEOM_PATH )
305
+ # geometry bool operation
306
+ geo = control_arm
307
+ geom = {"geo" : geo }
308
+ # set bounds
309
+ BOUNDS_X , BOUNDS_Y , BOUNDS_Z = control_arm .bounds
310
+ samples = geom ["geo" ].sample_interior (
311
+ cfg .EVAL .batch_size .visualizer_vtu ,
312
+ criteria = lambda x , y , z : (
313
+ (BOUNDS_X [0 ] < x )
314
+ & (x < BOUNDS_X [1 ])
315
+ & (BOUNDS_Y [0 ] < y )
316
+ & (y < BOUNDS_Y [1 ])
317
+ & (BOUNDS_Z [0 ] < z )
318
+ & (z < BOUNDS_Z [1 ])
319
+ ),
320
+ )
321
+ pred_input_dict = {
322
+ k : v for k , v in samples .items () if k in cfg .MODEL .disp_net .input_keys
323
+ }
324
+
325
+ output_dict = predictor .predict (pred_input_dict , cfg .INFER .batch_size )
326
+
327
+ # mapping data to output_keys
328
+ output_keys = (
329
+ cfg .MODEL .disp_net .output_keys
330
+ + cfg .MODEL .stress_net .output_keys
331
+ + cfg .MODEL .inverse_lambda_net .output_keys
332
+ + cfg .MODEL .inverse_mu_net .output_keys
333
+ )
334
+ output_dict = {
335
+ store_key : output_dict [infer_key ]
336
+ for store_key , infer_key in zip (output_keys , output_dict .keys ())
337
+ }
338
+ output_dict .update (pred_input_dict )
339
+ vtu .save_vtu_from_dict (
340
+ osp .join (cfg .output_dir , "vis" ),
341
+ output_dict ,
342
+ cfg .MODEL .disp_net .input_keys ,
343
+ output_keys ,
344
+ 1 ,
345
+ )
346
+
347
+
268
348
@hydra .main (
269
349
version_base = None , config_path = "./conf" , config_name = "inverse_parameter.yaml"
270
350
)
@@ -273,8 +353,14 @@ def main(cfg: DictConfig):
273
353
train (cfg )
274
354
elif cfg .mode == "eval" :
275
355
evaluate (cfg )
356
+ elif cfg .mode == "export" :
357
+ export (cfg )
358
+ elif cfg .mode == "infer" :
359
+ inference (cfg )
276
360
else :
277
- raise ValueError (f"cfg.mode should in ['train', 'eval'], but got '{ cfg .mode } '" )
361
+ raise ValueError (
362
+ f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{ cfg .mode } '"
363
+ )
278
364
279
365
280
366
if __name__ == "__main__" :
0 commit comments