Skip to content

Commit

Permalink
Merge branch 'pretty_eval_output' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Mar 21, 2024
2 parents da865dd + 7930574 commit 66cded1
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/darcy/darcy2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def poisson_ref_compute_func(_in):
cfg.NPOINT_PDE + cfg.NPOINT_BC, evenly=True
)
visualizer = {
"visualize_p": ppsci.visualize.VisualizerVtu(
"visualize_p_ux_uy": ppsci.visualize.VisualizerVtu(
vis_points,
{
"p": lambda d: d["p"],
Expand Down
18 changes: 15 additions & 3 deletions ppsci/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import numpy as np
import paddle
import paddle.distributed as dist
import prettytable
import sympy as sp
import visualdl as vdl
from packaging import version
Expand Down Expand Up @@ -545,10 +546,21 @@ def eval(self, epoch_id: int = 0) -> Tuple[float, Dict[str, Dict[str, float]]]:
self.eval_func = ppsci.solver.eval.eval_func

result = self.eval_func(self, epoch_id, self.log_freq)
metric_msg = ", ".join(
[self.eval_output_info[key].avg_info for key in self.eval_output_info]
metric_table = prettytable.PrettyTable(
["Name", "Value"],
title=f"Evaluation Metric(s){'' if epoch_id == 0 else f' at epoch {epoch_id}'}",
align="l",
)
logger.info(f"[Eval][Epoch {epoch_id}][Avg] {metric_msg}")

loss_msg = []
for name, value in self.eval_output_info.items():
if name.startswith("loss"):
loss_msg.append(f"{name}: {value.avg_fmt}")
else:
metric_table.add_row([name, value.avg_fmt])
loss_msg = ", ".join(loss_msg)

logger.info(f"[Eval][Epoch {epoch_id}][Avg] {loss_msg}\n{metric_table}")
self.eval_output_info.clear()

return result
Expand Down
6 changes: 6 additions & 0 deletions ppsci/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ def avg_info(self):
self.avg = float(self.avg)
return f"{self.name}: {self.avg:.5f}"

@property
def avg_fmt(self):
if isinstance(self.avg, paddle.Tensor):
self.avg = float(self.avg)
return f"{self.avg:.5e}"

@property
def total(self):
return f"{self.name}_sum: {self.sum:{self.fmt}}{self.postfix}"
Expand Down

0 comments on commit 66cded1

Please sign in to comment.