
Commit dbab44c

ywkim authored and afrozenator committed
Custom evaluation metrics (#1336)
* Custom evaluation metrics
* Fix Python 2 compatibility issue
* Fix notebook test
1 parent 0d23001 commit dbab44c

3 files changed (+43, -19 lines)


tensor2tensor/data_generators/problem.py

Lines changed: 13 additions & 0 deletions
@@ -367,6 +367,19 @@ def eval_metrics(self):
         metrics.Metrics.ACC_PER_SEQ, metrics.Metrics.NEG_LOG_PERPLEXITY
     ]
 
+  def eval_metric_fns(self, model_hparams):
+    metric_names = self.eval_metrics()
+    if not all([m in metrics.METRICS_FNS for m in metric_names]):
+      error_str = ("Unrecognized metric. Problem %s specified metrics "
+                   "%s. Recognized metrics are %s.")
+      raise ValueError(error_str % (self.name,
+                                    metric_names,
+                                    list(metrics.METRICS_FNS.keys())))
+    return {
+        metric_name: metrics.METRICS_FNS[metric_name]
+        for metric_name in metric_names
+    }
+
   def eval_hooks(self, features, logits, hparams):
     del features, logits, hparams
     return []
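The `eval_metric_fns` hook added above is the piece that enables custom evaluation metrics: a problem can now return an arbitrary name-to-function mapping instead of only names that must already exist in `metrics.METRICS_FNS`. Below is a minimal, hypothetical sketch of how a problem subclass might use it; the class, metric name, and metric body are illustrative assumptions, not part of this commit.

```python
import tensorflow as tf

from tensor2tensor.data_generators import problem
from tensor2tensor.utils import registry


def token_match_rate(predictions, labels, weights_fn=None):
  """Hypothetical custom metric: per-position exact-match rate.

  Follows the usual T2T convention of returning (values, weights) tensors.
  """
  del weights_fn  # This toy metric weights every position equally.
  outputs = tf.to_int32(tf.argmax(predictions, axis=-1))
  match = tf.to_float(tf.equal(outputs, tf.to_int32(labels)))
  return match, tf.ones_like(match)


@registry.register_problem
class MyCustomMetricProblem(problem.Problem):  # hypothetical problem class
  # generate_data(), hparams(), etc. omitted for brevity.

  def eval_metric_fns(self, model_hparams):
    # Keep the standard named metrics, then attach the custom function.
    metric_fns = super(MyCustomMetricProblem, self).eval_metric_fns(
        model_hparams)
    metric_fns["token_match_rate"] = token_match_rate
    return metric_fns
```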

tensor2tensor/utils/metrics.py

Lines changed: 25 additions & 14 deletions
@@ -602,15 +602,9 @@ def weights_fn_for_mp(problem_task_id):
     problem_name = problem_instance.name
     if problem_instance.was_reversed:
       problem_name += "_rev"
-    metrics = problem_instance.eval_metrics()
+    metrics = problem_instance.eval_metric_fns(model_hparams)
     if hasattr(model_hparams.problem, "task_list"):
-      metrics = model_hparams.problem.eval_metrics()
-      if not all([m in METRICS_FNS for m in metrics]):
-        error_str = ("Unrecognized metric. Problem %s specified metrics "
-                     "%s. Recognized metrics are %s.")
-        raise ValueError(error_str % (problem_name,
-                                      metrics,
-                                      list(METRICS_FNS.keys())))
+      metrics = model_hparams.problem.eval_metric_fns(model_hparams)
 
     tm = problem_instance.get_hparams(model_hparams).modality["targets"]
     if not isinstance(tm, dict):
@@ -622,8 +616,7 @@ def weights_fn_for_mp(problem_task_id):
       ptid = problem_instance.task_id  # pylint: disable=cell-var-from-loop
       weights_fn = weights_fn_for_mp(ptid)
 
-    for metric in metrics:
-      metric_fn = METRICS_FNS[metric]
+    for metric, metric_fn in six.iteritems(metrics):
       overload_eval_metric_name = getattr(
           model_hparams, "overload_eval_metric_name", None)
       if len(problems) == 1 and overload_eval_metric_name:
@@ -642,9 +635,10 @@ def weights_fn_for_mp(problem_task_id):
 
 def create_eager_metrics_for_problem(problem, model_hparams):
   """See create_eager_metrics."""
-  metric_names = problem.eval_metrics()
+  metric_fns = problem.eval_metric_fns(model_hparams)
   tm = problem.get_hparams(model_hparams).modality["targets"]
-  return create_eager_metrics(metric_names, weights_fn=tm.targets_weights_fn)
+  return create_eager_metrics_internal(
+      metric_fns, weights_fn=tm.targets_weights_fn)
 
 
 def create_eager_metrics(metric_names, weights_fn=common_layers.weights_all):
@@ -662,9 +656,26 @@ def create_eager_metrics(metric_names, weights_fn=common_layers.weights_all):
   """
   metric_fns = dict(
       [(name, METRICS_FNS[name]) for name in metric_names])
+  return create_eager_metrics_internal(metric_fns, weights_fn)
+
+
+def create_eager_metrics_internal(metric_fns,
+                                  weights_fn=common_layers.weights_all):
+  """Create metrics accumulators and averager for Eager mode.
+
+  Args:
+    metric_fns: dict<metric name, metric function>
+    weights_fn: function that takes labels and returns a weights mask. Defaults
+      to weights of all 1, i.e. common_layers.weights_all. Use
+      common_layers.weights_nonzero if labels have 0-padding.
+
+  Returns:
+    (accum_fn(predictions, targets) => None,
+     result_fn() => dict<str metric_name, float avg_val>)
+  """
   tfe_metrics = dict()
 
-  for name in metric_names:
+  for name in metric_fns:
     tfe_metrics[name] = tfe.metrics.Mean(name=name)
 
   def metric_accum(predictions, targets):
@@ -675,7 +686,7 @@ def metric_accum(predictions, targets):
 
   def metric_means():
     avgs = {}
-    for name in metric_names:
+    for name in metric_fns:
       avgs[name] = tfe_metrics[name].result().numpy()
     return avgs
 
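With this refactor, `create_eager_metrics` becomes a thin wrapper that resolves metric names through `METRICS_FNS` and delegates to the new `create_eager_metrics_internal`, which operates directly on a name-to-function dict, so custom functions returned by `eval_metric_fns` flow through unchanged. A rough usage sketch, assuming an eager evaluation loop that is not shown here:

```python
from tensor2tensor.utils import metrics

# Any callable with the (predictions, labels, weights_fn=...) signature can
# be mixed into this dict alongside the built-in metrics.
metric_fns = {
    metrics.Metrics.ACC: metrics.METRICS_FNS[metrics.Metrics.ACC],
    metrics.Metrics.NEG_LOG_PERPLEXITY:
        metrics.METRICS_FNS[metrics.Metrics.NEG_LOG_PERPLEXITY],
}

accum_fn, result_fn = metrics.create_eager_metrics_internal(metric_fns)

# Inside the (assumed) eager eval loop:
#   accum_fn(predictions, targets)   # updates the running means
# After the loop:
#   print(result_fn())               # {metric name: float average, ...}
```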

tensor2tensor/utils/t2t_model.py

Lines changed: 5 additions & 5 deletions
@@ -1721,7 +1721,7 @@ def create_tpu_eval_metrics_fn(problem, model_hparams):
   """Create the metrics_fn that TPUEstimatorSpec expects."""
 
   metric_fns = []
-  eval_metrics = problem.eval_metrics()
+  eval_metrics = problem.eval_metric_fns(model_hparams)
 
   tm = _create_target_modality(problem.get_hparams(model_hparams).modality)
   if isinstance(tm, dict):
@@ -1739,12 +1739,12 @@ def wrapped_metric_fn(logits, labels, features, weights_fn=weights_fn):
 
         return wrapped_metric_fn
 
-      for metric in eval_metrics:
+      for metric, metric_fn in six.iteritems(eval_metrics):
         if metric in TPU_METRIC_BLACKLIST:
           log_warn("Skipping eval metric %s in TPU_METRIC_BLACKLIST", metric)
           continue
         name = "%s/metrics-%s/%s" % (k, problem.name, metric)
-        metric_fns.append((name, make_metric_fn(metrics.METRICS_FNS[metric])))
+        metric_fns.append((name, make_metric_fn(metric_fn)))
   else:
     weights_fn = tm.targets_weights_fn
 
@@ -1759,12 +1759,12 @@ def wrapped_metric_fn(logits, labels, features):
 
       return wrapped_metric_fn
 
-    for metric in eval_metrics:
+    for metric, metric_fn in six.iteritems(eval_metrics):
       if metric in TPU_METRIC_BLACKLIST:
         log_warn("Skipping eval metric %s in TPU_METRIC_BLACKLIST", metric)
         continue
       name = "metrics-%s/%s" % (problem.name, metric)
-      metric_fns.append((name, make_metric_fn(metrics.METRICS_FNS[metric])))
+      metric_fns.append((name, make_metric_fn(metric_fn)))
 
   def all_metrics_fn(**kwargs):
     """Construct metrics dictionary."""
