|
25 | 25 | from paddle import io |
26 | 26 |
|
27 | 27 | from ppsci.solver import printer |
| 28 | +from ppsci.utils import logger |
28 | 29 | from ppsci.utils import misc |
29 | 30 |
|
30 | 31 | if TYPE_CHECKING: |
@@ -167,10 +168,11 @@ def _eval_by_dataset( |
167 | 168 | for metric_name, metric_func in _validator.metric.items(): |
168 | 169 | # NOTE: compute metric with entire output and label |
169 | 170 | metric_dict = metric_func(all_output, all_label) |
170 | | - assert metric_name not in metric_dict_group, ( |
171 | | - f"Metric name({metric_name}) already exists, please ensure all metric " |
172 | | - "names are unique over all validators." |
173 | | - ) |
| 171 | + if metric_name in metric_dict_group: |
| 172 | + logger.warning( |
| 173 | + f"Metric name({metric_name}) already exists, please ensure " |
| 174 | + "all metric names are unique over all validators." |
| 175 | + ) |
174 | 176 | metric_dict_group[metric_name] = { |
175 | 177 | k: float(v) for k, v in metric_dict.items() |
176 | 178 | } |
@@ -254,10 +256,11 @@ def _eval_by_batch( |
254 | 256 |
|
255 | 257 | # collect batch metric |
256 | 258 | for metric_name, metric_func in _validator.metric.items(): |
257 | | - assert metric_name not in metric_dict_group, ( |
258 | | - f"Metric name({metric_name}) already exists, please ensure all metric " |
259 | | - "names are unique over all validators." |
260 | | - ) |
| 259 | + if metric_name in metric_dict_group: |
| 260 | + logger.warning( |
| 261 | + f"Metric name({metric_name}) already exists, please ensure " |
| 262 | + "all metric names are unique over all validators." |
| 263 | + ) |
261 | 264 | metric_dict_group[metric_name] = misc.Prettydefaultdict(list) |
262 | 265 | metric_dict = metric_func(output_dict, label_dict) |
263 | 266 | for var_name, metric_value in metric_dict.items(): |
|
0 commit comments