@@ -283,8 +283,8 @@ def export(cfg: DictConfig):
283283
284284 input_spec = [
285285 {
286- key : InputSpec ([None , 16 , 128 ], "float32" , name = key )
287- for key in model . input_keys
286+ "states" : InputSpec ([1 , 255 , 3 , 64 , 128 ], "float32" , name = "states" ),
287+ "visc" : InputSpec ([ 1 , 1 ], "float32" , name = "visc" ),
288288 },
289289 ]
290290
@@ -309,6 +309,7 @@ def inference(cfg: DictConfig):
309309
310310 input_dict = {
311311 "states" : dataset .data [: cfg .VIS_DATA_NUMS , :- 1 ],
312+ "visc" : dataset .visc [: cfg .VIS_DATA_NUMS ],
312313 }
313314
314315 output_dict = predictor .predict (input_dict )
@@ -319,17 +320,19 @@ def inference(cfg: DictConfig):
319320 store_key : output_dict [infer_key ]
320321 for store_key , infer_key in zip (output_keys , output_dict .keys ())
321322 }
322-
323- input_dict = {
324- "states" : dataset .data [: cfg .VIS_DATA_NUMS , 1 :],
325- }
326-
327- data_dict = {** input_dict , ** output_dict }
328323 for i in range (cfg .VIS_DATA_NUMS ):
329- ppsci .visualize .save_plot_from_3d_dict (
324+ ppsci .visualize .plot . save_plot_from_2d_dict (
330325 f"./cylinder_transformer_pred_{ i } " ,
331- {key : value [i ] for key , value in data_dict .items ()},
332- ("states" , "pred_states" ),
326+ {
327+ "pred_ux" : output_dict ["pred_states" ][i ][:, 0 ],
328+ "pred_uy" : output_dict ["pred_states" ][i ][:, 1 ],
329+ "pred_p" : output_dict ["pred_states" ][i ][:, 2 ],
330+ },
331+ ("pred_ux" , "pred_uy" , "pred_p" ),
332+ 10 ,
333+ 20 ,
334+ np .linspace (- 2 , 14 , 9 ),
335+ np .linspace (- 4 , 4 , 5 ),
333336 )
334337
335338
0 commit comments