@@ -200,14 +200,88 @@ def evaluate(cfg: DictConfig):
200200 solver .visualize ()
201201
202202
203+ def export (cfg : DictConfig ):
204+ from paddle import nn
205+ from paddle .static import InputSpec
206+
207+ # set model
208+ model = ppsci .arch .MLP (** cfg .MODEL )
209+ # initialize equation
210+ equation = {"VIV" : ppsci .equation .Vibration (2 , - 4 , 0 )}
211+ # initialize solver
212+ solver = ppsci .solver .Solver (
213+ model ,
214+ equation = equation ,
215+ pretrained_model_path = cfg .INFER .pretrained_model_path ,
216+ )
217+ # Convert equation to func
218+ f_func = ppsci .lambdify (
219+ solver .equation ["VIV" ].equations ["f" ],
220+ solver .model ,
221+ list (solver .equation ["VIV" ].learnable_parameters ),
222+ )
223+
224+ class Wrapped_Model (nn .Layer ):
225+ def __init__ (self , model , func ):
226+ super ().__init__ ()
227+ self .model = model
228+ self .func = func
229+
230+ def forward (self , x ):
231+ model_out = self .model (x )
232+ func_out = self .func (x )
233+ return {** model_out , "f" : func_out }
234+
235+ solver .model = Wrapped_Model (model , f_func )
236+ # export models
237+ input_spec = [
238+ {key : InputSpec ([None , 1 ], "float32" , name = key ) for key in model .input_keys },
239+ ]
240+ solver .export (input_spec , cfg .INFER .export_path , skip_prune_program = True )
241+
242+
243+ def inference (cfg : DictConfig ):
244+ from deploy .python_infer import pinn_predictor
245+
246+ # set model predictor
247+ predictor = pinn_predictor .PINNPredictor (cfg )
248+
249+ infer_mat = ppsci .utils .reader .load_mat_file (
250+ cfg .VIV_DATA_PATH ,
251+ ("t_f" , "eta_gt" , "f_gt" ),
252+ alias_dict = {"eta_gt" : "eta" , "f_gt" : "f" },
253+ )
254+
255+ input_dict = {key : infer_mat [key ] for key in cfg .INFER .input_keys }
256+
257+ output_dict = predictor .predict (input_dict , cfg .INFER .batch_size )
258+
259+ # mapping data to cfg.INFER.output_keys
260+ output_dict = {
261+ store_key : output_dict [infer_key ]
262+ for store_key , infer_key in zip (cfg .INFER .output_keys , output_dict .keys ())
263+ }
264+ infer_mat .update (output_dict )
265+
266+ ppsci .visualize .plot .save_plot_from_1d_dict (
267+ "./viv_pred" , infer_mat , ("t_f" ,), ("eta" , "eta_gt" , "f" , "f_gt" )
268+ )
269+
270+
203271@hydra .main (version_base = None , config_path = "./conf" , config_name = "viv.yaml" )
204272def main (cfg : DictConfig ):
205273 if cfg .mode == "train" :
206274 train (cfg )
207275 elif cfg .mode == "eval" :
208276 evaluate (cfg )
277+ elif cfg .mode == "export" :
278+ export (cfg )
279+ elif cfg .mode == "infer" :
280+ inference (cfg )
209281 else :
210- raise ValueError (f"cfg.mode should in ['train', 'eval'], but got '{ cfg .mode } '" )
282+ raise ValueError (
283+ f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{ cfg .mode } '"
284+ )
211285
212286
213287if __name__ == "__main__" :
0 commit comments