From 2fdee459175b9729a88a2576d60adc42591f822b Mon Sep 17 00:00:00 2001 From: zzm <95690929+zhiminzhang0830@users.noreply.github.com> Date: Thu, 30 May 2024 21:05:35 +0800 Subject: [PATCH] fix export and infer (#916) --- .../transformer_physx/train_transformer.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/examples/cylinder/2d_unsteady/transformer_physx/train_transformer.py b/examples/cylinder/2d_unsteady/transformer_physx/train_transformer.py index e703af1fd..34eb6c288 100644 --- a/examples/cylinder/2d_unsteady/transformer_physx/train_transformer.py +++ b/examples/cylinder/2d_unsteady/transformer_physx/train_transformer.py @@ -283,8 +283,8 @@ def export(cfg: DictConfig): input_spec = [ { - key: InputSpec([None, 16, 128], "float32", name=key) - for key in model.input_keys + "states": InputSpec([1, 255, 3, 64, 128], "float32", name="states"), + "visc": InputSpec([1, 1], "float32", name="visc"), }, ] @@ -309,6 +309,7 @@ def inference(cfg: DictConfig): input_dict = { "states": dataset.data[: cfg.VIS_DATA_NUMS, :-1], + "visc": dataset.visc[: cfg.VIS_DATA_NUMS], } output_dict = predictor.predict(input_dict) @@ -319,17 +320,19 @@ def inference(cfg: DictConfig): store_key: output_dict[infer_key] for store_key, infer_key in zip(output_keys, output_dict.keys()) } - - input_dict = { - "states": dataset.data[: cfg.VIS_DATA_NUMS, 1:], - } - - data_dict = {**input_dict, **output_dict} for i in range(cfg.VIS_DATA_NUMS): - ppsci.visualize.save_plot_from_3d_dict( + ppsci.visualize.plot.save_plot_from_2d_dict( f"./cylinder_transformer_pred_{i}", - {key: value[i] for key, value in data_dict.items()}, - ("states", "pred_states"), + { + "pred_ux": output_dict["pred_states"][i][:, 0], + "pred_uy": output_dict["pred_states"][i][:, 1], + "pred_p": output_dict["pred_states"][i][:, 2], + }, + ("pred_ux", "pred_uy", "pred_p"), + 10, + 20, + np.linspace(-2, 14, 9), + np.linspace(-4, 4, 5), )