This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Custom evaluation metrics #1336

Merged · 3 commits · Jan 4, 2019
13 changes: 13 additions & 0 deletions tensor2tensor/data_generators/problem.py
@@ -367,6 +367,19 @@ def eval_metrics(self):
metrics.Metrics.ACC_PER_SEQ, metrics.Metrics.NEG_LOG_PERPLEXITY
]

def eval_metric_fns(self, model_hparams):
metric_names = self.eval_metrics()
if not all([m in metrics.METRICS_FNS for m in metric_names]):
error_str = ("Unrecognized metric. Problem %s specified metrics "
"%s. Recognized metrics are %s.")
raise ValueError(error_str % (self.name,
metric_names,
list(metrics.METRICS_FNS.keys())))
return {
metric_name: metrics.METRICS_FNS[metric_name]
for metric_name in metric_names
}

def eval_hooks(self, features, logits, hparams):
del features, logits, hparams
return []
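With `eval_metric_fns` in place, a `Problem` can hand back arbitrary metric callables instead of being limited to names registered in `metrics.METRICS_FNS`. A minimal sketch of how a subclass might use the hook; the problem class and the `rounded_accuracy` metric are hypothetical, and the usual T2T convention of metric functions returning a `(values, weights)` pair is assumed:

```python
# Hypothetical example: a problem that mixes built-in and custom eval metrics.
import tensorflow as tf

from tensor2tensor.data_generators import problem
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import registry


def rounded_accuracy(predictions, labels,
                     weights_fn=common_layers.weights_nonzero):
  """Hypothetical custom metric: accuracy of arg-maxed predictions.

  Follows the assumed T2T metric contract of returning (values, weights).
  """
  weights = weights_fn(labels)
  outputs = tf.to_int32(tf.argmax(predictions, axis=-1))
  return tf.to_float(tf.equal(outputs, tf.to_int32(labels))), weights


@registry.register_problem
class MyTextProblem(problem.Problem):  # hypothetical problem
  """Illustrative problem mixing built-in and custom eval metrics."""

  def eval_metric_fns(self, model_hparams):
    # Built-in metrics still come from eval_metrics() via the base class...
    metric_fns = super(MyTextProblem, self).eval_metric_fns(model_hparams)
    # ...and any callable can be added, even without an entry in
    # metrics.METRICS_FNS.
    metric_fns["rounded_accuracy"] = rounded_accuracy
    return metric_fns
```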
39 changes: 25 additions & 14 deletions tensor2tensor/utils/metrics.py
@@ -602,15 +602,9 @@ def weights_fn_for_mp(problem_task_id):
problem_name = problem_instance.name
if problem_instance.was_reversed:
problem_name += "_rev"
metrics = problem_instance.eval_metrics()
metrics = problem_instance.eval_metric_fns(model_hparams)
if hasattr(model_hparams.problem, "task_list"):
metrics = model_hparams.problem.eval_metrics()
if not all([m in METRICS_FNS for m in metrics]):
error_str = ("Unrecognized metric. Problem %s specified metrics "
"%s. Recognized metrics are %s.")
raise ValueError(error_str % (problem_name,
metrics,
list(METRICS_FNS.keys())))
metrics = model_hparams.problem.eval_metric_fns(model_hparams)

tm = problem_instance.get_hparams(model_hparams).modality["targets"]
if not isinstance(tm, dict):
@@ -622,8 +616,7 @@ def weights_fn_for_mp(problem_task_id):
ptid = problem_instance.task_id # pylint: disable=cell-var-from-loop
weights_fn = weights_fn_for_mp(ptid)

for metric in metrics:
metric_fn = METRICS_FNS[metric]
for metric, metric_fn in six.iteritems(metrics):
overload_eval_metric_name = getattr(
model_hparams, "overload_eval_metric_name", None)
if len(problems) == 1 and overload_eval_metric_name:
@@ -642,9 +635,10 @@

def create_eager_metrics_for_problem(problem, model_hparams):
"""See create_eager_metrics."""
metric_names = problem.eval_metrics()
metric_fns = problem.eval_metric_fns(model_hparams)
tm = problem.get_hparams(model_hparams).modality["targets"]
return create_eager_metrics(metric_names, weights_fn=tm.targets_weights_fn)
return create_eager_metrics_internal(
metric_fns, weights_fn=tm.targets_weights_fn)


def create_eager_metrics(metric_names, weights_fn=common_layers.weights_all):
@@ -662,9 +656,26 @@ def create_eager_metrics(metric_names, weights_fn=common_layers.weights_all):
"""
metric_fns = dict(
[(name, METRICS_FNS[name]) for name in metric_names])
return create_eager_metrics_internal(metric_fns, weights_fn)


def create_eager_metrics_internal(metric_fns,
weights_fn=common_layers.weights_all):
"""Create metrics accumulators and averager for Eager mode.

Args:
metric_fns: dict<metric name, metric function>
weights_fn: function that takes labels and returns a weights mask. Defaults
to weights of all 1, i.e. common_layers.weights_all. Use
common_layers.weights_nonzero if labels have 0-padding.

Returns:
(accum_fn(predictions, targets) => None,
result_fn() => dict<str metric_name, float avg_val>)
"""
tfe_metrics = dict()

for name in metric_names:
for name in metric_fns:
tfe_metrics[name] = tfe.metrics.Mean(name=name)

def metric_accum(predictions, targets):
@@ -675,7 +686,7 @@ def metric_accum(predictions, targets):

def metric_means():
avgs = {}
for name in metric_names:
for name in metric_fns:
avgs[name] = tfe_metrics[name].result().numpy()
return avgs

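A rough usage sketch of the new helper in eager mode. The metric selection below is illustrative and could equally come from `problem.eval_metric_fns(hparams)`; TF 1.x eager execution and the existing `padded_*` metrics in this module are assumed:

```python
# Rough usage sketch of create_eager_metrics_internal (TF 1.x eager assumed).
import tensorflow as tf

from tensor2tensor.utils import metrics

tf.enable_eager_execution()

# In practice this dict would come from problem.eval_metric_fns(hparams).
metric_fns = {
    "accuracy": metrics.padded_accuracy,
    "neg_log_perplexity": metrics.padded_neg_log_perplexity,
}
accum_fn, result_fn = metrics.create_eager_metrics_internal(metric_fns)

# Fake logits [batch, length, 1, 1, vocab] and targets [batch, length, 1, 1].
predictions = tf.random_uniform([2, 5, 1, 1, 10])
targets = tf.random_uniform([2, 5, 1, 1], maxval=10, dtype=tf.int32)

accum_fn(predictions, targets)
print(result_fn())  # e.g. {"accuracy": 0.12, "neg_log_perplexity": -2.4}
```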
10 changes: 5 additions & 5 deletions tensor2tensor/utils/t2t_model.py
@@ -1721,7 +1721,7 @@ def create_tpu_eval_metrics_fn(problem, model_hparams):
"""Create the metrics_fn that TPUEstimatorSpec expects."""

metric_fns = []
eval_metrics = problem.eval_metrics()
eval_metrics = problem.eval_metric_fns(model_hparams)

tm = _create_target_modality(problem.get_hparams(model_hparams).modality)
if isinstance(tm, dict):
@@ -1739,12 +1739,12 @@ def wrapped_metric_fn(logits, labels, features, weights_fn=weights_fn):

return wrapped_metric_fn

for metric in eval_metrics:
for metric, metric_fn in six.iteritems(eval_metrics):
if metric in TPU_METRIC_BLACKLIST:
log_warn("Skipping eval metric %s in TPU_METRIC_BLACKLIST", metric)
continue
name = "%s/metrics-%s/%s" % (k, problem.name, metric)
metric_fns.append((name, make_metric_fn(metrics.METRICS_FNS[metric])))
metric_fns.append((name, make_metric_fn(metric_fn)))
else:
weights_fn = tm.targets_weights_fn

@@ -1759,12 +1759,12 @@ def wrapped_metric_fn(logits, labels, features):

return wrapped_metric_fn

for metric in eval_metrics:
for metric, metric_fn in six.iteritems(eval_metrics):
if metric in TPU_METRIC_BLACKLIST:
log_warn("Skipping eval metric %s in TPU_METRIC_BLACKLIST", metric)
continue
name = "metrics-%s/%s" % (problem.name, metric)
metric_fns.append((name, make_metric_fn(metrics.METRICS_FNS[metric])))
metric_fns.append((name, make_metric_fn(metric_fn)))

def all_metrics_fn(**kwargs):
"""Construct metrics dictionary."""
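For reference, the `metrics_fn` built here is what ultimately lands in the TPU estimator spec. The sketch below is not the literal code in `t2t_model.py` but shows the assumed hand-off (TF 1.x `tf.contrib.tpu`; `logits`, `labels`, `features`, and `loss` are placeholders for the model's real tensors):

```python
# Rough sketch of how the metrics_fn is consumed on TPU (TF 1.x assumed).
# The exact wiring in t2t_model's eval spec may differ in detail.
import tensorflow as tf

from tensor2tensor.utils import metrics as metrics_lib


def tpu_eval_spec(problem, model_hparams, logits, labels, features, loss):
  metrics_fn = metrics_lib.create_tpu_eval_metrics_fn(problem, model_hparams)
  # TPUEstimatorSpec calls metrics_fn with these tensors on the host and
  # expects back the dict of (value, update_op) pairs that all_metrics_fn
  # assembles from the wrapped per-metric functions.
  return tf.contrib.tpu.TPUEstimatorSpec(
      tf.estimator.ModeKeys.EVAL,
      loss=loss,
      eval_metrics=(metrics_fn,
                    {"logits": logits, "labels": labels,
                     "features": features}))
```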