@@ -602,15 +602,9 @@ def weights_fn_for_mp(problem_task_id):
     problem_name = problem_instance.name
     if problem_instance.was_reversed:
       problem_name += "_rev"
-    metrics = problem_instance.eval_metrics()
+    metrics = problem_instance.eval_metric_fns(model_hparams)
     if hasattr(model_hparams.problem, "task_list"):
-      metrics = model_hparams.problem.eval_metrics()
-      if not all([m in METRICS_FNS for m in metrics]):
-        error_str = ("Unrecognized metric. Problem %s specified metrics "
-                     "%s. Recognized metrics are %s.")
-        raise ValueError(error_str % (problem_name,
-                                      metrics,
-                                      list(METRICS_FNS.keys())))
+      metrics = model_hparams.problem.eval_metric_fns(model_hparams)

     tm = problem_instance.get_hparams(model_hparams).modality["targets"]
     if not isinstance(tm, dict):
@@ -622,8 +616,7 @@ def weights_fn_for_mp(problem_task_id):
     ptid = problem_instance.task_id  # pylint: disable=cell-var-from-loop
     weights_fn = weights_fn_for_mp(ptid)

-    for metric in metrics:
-      metric_fn = METRICS_FNS[metric]
+    for metric, metric_fn in six.iteritems(metrics):
       overload_eval_metric_name = getattr(
           model_hparams, "overload_eval_metric_name", None)
       if len(problems) == 1 and overload_eval_metric_name:
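Note: the two hunks above change the metric plumbing from a list of metric names (each looked up in METRICS_FNS) to a dict of name -> metric function returned by eval_metric_fns(model_hparams), which is why the "Unrecognized metric" validation goes away and the loop switches to six.iteritems. A minimal sketch of the new contract, using a hypothetical ToyProblem class and dummy metric function that are not part of this diff:

    import six


    def dummy_accuracy(predictions, labels, weights_fn=None):
      # Stand-in for a real metric function such as those in METRICS_FNS.
      return 1.0


    class ToyProblem(object):
      """Hypothetical problem illustrating the eval_metric_fns contract."""

      def eval_metric_fns(self, model_hparams):
        del model_hparams  # a problem may inspect hparams to pick its metrics
        return {"accuracy": dummy_accuracy}


    # Callers receive name/function pairs directly, so there is no
    # METRICS_FNS lookup left to fail.
    for metric, metric_fn in six.iteritems(ToyProblem().eval_metric_fns(None)):
      print(metric, metric_fn(None, None))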
@@ -642,9 +635,10 @@ def weights_fn_for_mp(problem_task_id):

 def create_eager_metrics_for_problem(problem, model_hparams):
   """See create_eager_metrics."""
-  metric_names = problem.eval_metrics()
+  metric_fns = problem.eval_metric_fns(model_hparams)
   tm = problem.get_hparams(model_hparams).modality["targets"]
-  return create_eager_metrics(metric_names, weights_fn=tm.targets_weights_fn)
+  return create_eager_metrics_internal(
+      metric_fns, weights_fn=tm.targets_weights_fn)


 def create_eager_metrics(metric_names, weights_fn=common_layers.weights_all):
@@ -662,9 +656,26 @@ def create_eager_metrics(metric_names, weights_fn=common_layers.weights_all):
   """
   metric_fns = dict(
       [(name, METRICS_FNS[name]) for name in metric_names])
+  return create_eager_metrics_internal(metric_fns, weights_fn)
+
+
+def create_eager_metrics_internal(metric_fns,
+                                  weights_fn=common_layers.weights_all):
+  """Create metrics accumulators and averager for Eager mode.
+
+  Args:
+    metric_fns: dict<metric name, metric function>
+    weights_fn: function that takes labels and returns a weights mask. Defaults
+      to weights of all 1, i.e. common_layers.weights_all. Use
+      common_layers.weights_nonzero if labels have 0-padding.
+
+  Returns:
+    (accum_fn(predictions, targets) => None,
+     result_fn() => dict<str metric_name, float avg_val>)
+  """
   tfe_metrics = dict()

-  for name in metric_names:
+  for name in metric_fns:
     tfe_metrics[name] = tfe.metrics.Mean(name=name)

   def metric_accum(predictions, targets):
@@ -675,7 +686,7 @@ def metric_accum(predictions, targets):

   def metric_means():
     avgs = {}
-    for name in metric_names:
+    for name in metric_fns:
       avgs[name] = tfe_metrics[name].result().numpy()
     return avgs

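For context, a minimal usage sketch of the new create_eager_metrics_internal helper; the metric function, inputs, and printed value are illustrative only, and it assumes a TF 1.x build where tf.enable_eager_execution and tf.to_float exist, with metric functions following the usual (values, weights) return convention of METRICS_FNS entries:

    import tensorflow as tf

    tf.enable_eager_execution()


    def frac_equal(predictions, labels, weights_fn=None):
      # Toy metric: elementwise correctness, weighted by all-ones weights.
      weights = tf.ones_like(labels, dtype=tf.float32)
      return tf.to_float(tf.equal(predictions, labels)), weights


    # Per the docstring above, the helper returns (accum_fn, result_fn).
    accum_fn, result_fn = create_eager_metrics_internal({"acc": frac_equal})
    accum_fn(tf.constant([1, 0, 1]), tf.constant([1, 1, 1]))
    print(result_fn())  # e.g. {'acc': 0.6666667}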