Skip to content

Commit

Permalink
fix export and infer (#916)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiminzhang0830 authored May 30, 2024
1 parent 3729e14 commit 2fdee45
Showing 1 changed file with 14 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,8 @@ def export(cfg: DictConfig):

input_spec = [
{
key: InputSpec([None, 16, 128], "float32", name=key)
for key in model.input_keys
"states": InputSpec([1, 255, 3, 64, 128], "float32", name="states"),
"visc": InputSpec([1, 1], "float32", name="visc"),
},
]

Expand All @@ -309,6 +309,7 @@ def inference(cfg: DictConfig):

input_dict = {
"states": dataset.data[: cfg.VIS_DATA_NUMS, :-1],
"visc": dataset.visc[: cfg.VIS_DATA_NUMS],
}

output_dict = predictor.predict(input_dict)
Expand All @@ -319,17 +320,19 @@ def inference(cfg: DictConfig):
store_key: output_dict[infer_key]
for store_key, infer_key in zip(output_keys, output_dict.keys())
}

input_dict = {
"states": dataset.data[: cfg.VIS_DATA_NUMS, 1:],
}

data_dict = {**input_dict, **output_dict}
for i in range(cfg.VIS_DATA_NUMS):
ppsci.visualize.save_plot_from_3d_dict(
ppsci.visualize.plot.save_plot_from_2d_dict(
f"./cylinder_transformer_pred_{i}",
{key: value[i] for key, value in data_dict.items()},
("states", "pred_states"),
{
"pred_ux": output_dict["pred_states"][i][:, 0],
"pred_uy": output_dict["pred_states"][i][:, 1],
"pred_p": output_dict["pred_states"][i][:, 2],
},
("pred_ux", "pred_uy", "pred_p"),
10,
20,
np.linspace(-2, 14, 9),
np.linspace(-4, 4, 5),
)


Expand Down

0 comments on commit 2fdee45

Please sign in to comment.