2020
2121import hydra
2222import numpy as np
23+ import paddle
2324from matplotlib import pyplot as plt
2425from omegaconf import DictConfig
2526
@@ -34,7 +35,9 @@ def split_tensors(
3435
3536 Args:
3637 tensors (List[np.array]): Non-empty tensor list.
37- ratio (float): Split ratio. For example, tensor list A is split to A1 and A2. len(A1) / len(A) = ratio.
38+ ratio (float): Split ratio. For example, tensor list A is split to A1 and A2.
39+ len(A1) / len(A) = ratio.
40+
3841 Returns:
3942 Tuple[List[np.array], List[np.array]]: Split tensors.
4043 """
@@ -192,10 +195,7 @@ def predict_and_save_plot(
192195 plt .colorbar (orientation = "horizontal" )
193196 plt .tight_layout ()
194197 plt .show ()
195- plt .savefig (
196- os .path .join (plot_dir , f"cfd_{ index } .png" ),
197- bbox_inches = "tight" ,
198- )
198+ plt .savefig (os .path .join (plot_dir , f"cfd_{ index } .png" ), bbox_inches = "tight" )
199199
200200
201201def train (cfg : DictConfig ):
@@ -376,17 +376,17 @@ def evaluate(cfg: DictConfig):
376376
377377 # define loss
378378 def loss_expr (
379- output_dict : Dict [str , np . ndarray ],
380- label_dict : Dict [str , np . ndarray ] = None ,
381- weight_dict : Dict [str , np . ndarray ] = None ,
382- ) -> float :
379+ output_dict : Dict [str , "paddle.Tensor" ],
380+ label_dict : Dict [str , "paddle.Tensor" ] = None ,
381+ weight_dict : Dict [str , "paddle.Tensor" ] = None ,
382+ ) -> Dict [ str , "paddle.Tensor" ] :
383383 output = output_dict ["output" ]
384384 y = label_dict ["output" ]
385385 loss_u = (output [:, 0 :1 , :, :] - y [:, 0 :1 , :, :]) ** 2
386386 loss_v = (output [:, 1 :2 , :, :] - y [:, 1 :2 , :, :]) ** 2
387387 loss_p = (output [:, 2 :3 , :, :] - y [:, 2 :3 , :, :]).abs ()
388388 loss = (loss_u + loss_v + loss_p ) / CHANNELS_WEIGHTS
389- return loss .sum ()
389+ return { "custom_loss" : loss .sum ()}
390390
391391 # manually build validator
392392 eval_dataloader_cfg = {
@@ -404,10 +404,10 @@ def loss_expr(
404404 }
405405
406406 def metric_expr (
407- output_dict : Dict [str , np . ndarray ],
408- label_dict : Dict [str , np . ndarray ] = None ,
409- weight_dict : Dict [str , np . ndarray ] = None ,
410- ) -> Dict [str , float ]:
407+ output_dict : Dict [str , "paddle.Tensor" ],
408+ label_dict : Dict [str , "paddle.Tensor" ] = None ,
409+ weight_dict : Dict [str , "paddle.Tensor" ] = None ,
410+ ) -> Dict [str , "paddle.Tensor" ]:
411411 output = output_dict ["output" ]
412412 y = label_dict ["output" ]
413413 total_mse = ((output - y ) ** 2 ).sum () / len (test_x )
0 commit comments