This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit c640bd0

Hyperparticle authored and Mesh TensorFlow Team committed
Save scores to multiple shards.
PiperOrigin-RevId: 393882034
1 parent 739fb09 commit c640bd0


mesh_tensorflow/transformer/utils.py

Lines changed: 135 additions & 8 deletions
@@ -26,9 +26,11 @@

 import functools
 import itertools
+import math
 import os
 import random
 import re
+import time

 import gin
 import gin.tf
@@ -38,6 +40,7 @@
 from mesh_tensorflow.transformer import learning_rate_schedules
 from mesh_tensorflow.transformer import transformer
 import numpy as np
+import pandas as pd
 import pkg_resources
 import six
 import tensorflow.compat.v1 as tf
@@ -1654,6 +1657,52 @@ def get_sequence_length(tokens, pad_id=0):
   return scores


+@gin.configurable
+def save_scores_to_tfrecords(
+    results, vocabulary, scores_filename, shard_idx=0, save_ids_only=False):
+  """Processes results from scoring examples and saves them to TFRecord files.
+
+  Args:
+    results: list of dictionaries containing the results for each scored
+      example.
+    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
+      targets_vocabulary) tuple.
+    scores_filename: a string (path of file to write scores to).
+    shard_idx: an integer indicating the current index of the file for
+      sharding.
+    save_ids_only: if true, save only the ID that is prepended to the inputs.
+  """
+  results = _maybe_add_pretokenized_features(results, vocabulary)
+  scores = [r.get("scores", 0.0) for r in results]
+  targets = [r.get("targets_pretokenized", r["targets"]) for r in results]
+  inputs = [r.get("targets_neg_pretokenized", "") for r in results]
+
+  if save_ids_only:
+    inputs = [r.split(" ", 1)[0] for r in inputs]
+
+  table_path = "{}_{}.tfrecord".format(scores_filename, shard_idx)
+  tf.logging.info("Saving results to {}".format(table_path))
+
+  with tf.io.TFRecordWriter(table_path) as file_writer:
+    for input_, target, score in zip(inputs, targets, scores):
+      record_bytes = tf.train.Example(
+          features=tf.train.Features(
+              feature={
+                  "input":
+                      tf.train.Feature(
+                          bytes_list=tf.train.BytesList(
+                              value=[bytes(input_, "utf8")])),
+                  "target":
+                      tf.train.Feature(
+                          bytes_list=tf.train.BytesList(
+                              value=[bytes(target, "utf8")])),
+                  "score":
+                      tf.train.Feature(
+                          float_list=tf.train.FloatList(value=[score])),
+              })).SerializeToString()
+      file_writer.write(record_bytes)
+
+
+@gin.configurable
 def score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
                          vocabulary, score_postprocess_fn=save_scores,
                          num_examples=None):
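Each record written above carries three flat features: bytes "input", bytes "target", and a float "score". The shard files can be read back with standard TFRecord parsing; here is a minimal sketch (not part of this commit), assuming a scores_filename prefix of "scores" and two shards:

import tensorflow.compat.v1 as tf

# Feature spec mirroring what save_scores_to_tfrecords writes.
feature_spec = {
    "input": tf.io.FixedLenFeature([], tf.string),
    "target": tf.io.FixedLenFeature([], tf.string),
    "score": tf.io.FixedLenFeature([], tf.float32),
}

def parse_record(serialized):
  return tf.io.parse_single_example(serialized, feature_spec)

# Shards follow the "{scores_filename}_{shard_idx}.tfrecord" naming scheme.
shard_paths = ["scores_0.tfrecord", "scores_1.tfrecord"]
dataset = tf.data.TFRecordDataset(shard_paths).map(parse_record)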
@@ -1691,6 +1740,70 @@ def score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
   return score_postprocess_fn(results, vocabulary)


+@gin.configurable
+def score_with_estimator_lazy(
+    estimator, input_fn, eval_checkpoint_step, model_dir,
+    vocabulary, score_postprocess_fn=save_scores_to_tfrecords,
+    num_examples=None, num_examples_per_shard=10000):
+  """Score each example returned by input_fn lazily.
+
+  Args:
+    estimator: a TPUEstimator
+    input_fn: a function that returns a tf.data.Dataset with examples
+      containing the string field 'targets' and optionally the field 'inputs'
+    eval_checkpoint_step: int, list of ints, or None, see `eval_model`
+      docstring.
+    model_dir: string, estimator model_dir
+    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
+      targets_vocabulary) tuple
+    score_postprocess_fn: a function that takes in model outputs and
+      post-processes, saves, and returns them.
+    num_examples: int, the total # of examples being scored, None if unknown
+    num_examples_per_shard: int, the number of examples per file shard.
+
+  Returns:
+    None. Results are passed to score_postprocess_fn one shard at a time.
+  """
+  if num_examples is not None:
+    num_shards = math.ceil(num_examples / num_examples_per_shard)
+  else:
+    num_shards = None
+  tf.logging.info(
+      "Scoring {} examples with {} shards at {} examples per shard".format(
+          num_examples, num_shards, num_examples_per_shard))
+
+  checkpoint_path, = get_checkpoint_iterator(
+      eval_checkpoint_step, model_dir)
+  result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)
+
+  start = time.time()
+  results = []
+  shard_idx = 0
+
+  for i, result in enumerate(result_iter):
+    results.append(result)
+    num_results = len(results)
+    exceeded_num_examples = num_examples is not None and i >= num_examples
+
+    if num_results >= num_examples_per_shard or exceeded_num_examples:
+      score_postprocess_fn(results, vocabulary, shard_idx=shard_idx)
+
+      elapsed = time.time() - start
+      tf.logging.info(
+          "Scored {} results in {} s, {} examples/s for shard {}".format(
+              num_results, elapsed, num_results / elapsed, shard_idx))
+
+      results = []
+      shard_idx += 1
+      start = time.time()
+
+    if exceeded_num_examples:
+      break
+
+  if results:
+    score_postprocess_fn(results, vocabulary, shard_idx=shard_idx)
+
+
 def _maybe_add_pretokenized_features(examples, vocabulary):
   """Ensures decoded versions of "inputs" and "targets" exist in each example.

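With the default num_examples_per_shard=10000, scoring 25,000 examples gives math.ceil(25000 / 10000) = 3 shards: the postprocess function is invoked with 10,000, 10,000, and finally 5,000 buffered results, producing scores_0.tfrecord through scores_2.tfrecord. The same buffer-and-flush pattern in isolation, as an illustrative sketch (the function and variable names here are hypothetical, not from the commit):

import math

def flush_in_shards(result_iter, num_per_shard, flush_fn):
  """Buffers results and flushes one shard every num_per_shard items."""
  buf, shard_idx = [], 0
  for result in result_iter:
    buf.append(result)
    if len(buf) >= num_per_shard:
      flush_fn(buf, shard_idx)  # e.g. save_scores_to_tfrecords(..., shard_idx=shard_idx)
      buf, shard_idx = [], shard_idx + 1
  if buf:  # final partial shard
    flush_fn(buf, shard_idx)

# 25000 items at 10000 per shard -> 3 flushes, matching math.ceil(25000 / 10000).
flush_in_shards(iter(range(25000)), 10000,
                lambda b, i: print("shard", i, "holds", len(b), "results"))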
@@ -1712,9 +1825,17 @@ def _maybe_add_pretokenized_features(examples, vocabulary):
   for example in examples:
     for feature_name in ["inputs", "targets"]:
       pretokenized_feature_name = feature_name + "_pretokenized"
+      neg_pretokenized_feature_name = feature_name + "_neg_pretokenized"
       if feature_name in example and pretokenized_feature_name not in example:
-        s = vocabulary[feature_name].decode(example[feature_name].tolist())
+        ids = example[feature_name].tolist()
+
+        neg_ids = [abs(i) for i in ids if i < 0]
+        ids = [i for i in ids if i > 0]
+
+        s = vocabulary[feature_name].decode(ids)
         example[pretokenized_feature_name] = s
+        neg_s = vocabulary[feature_name].decode(neg_ids)
+        example[neg_pretokenized_feature_name] = neg_s

       if not added_pretokenized[feature_name]:
         added_pretokenized[feature_name] = True
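The negative-ID convention above lets a single token stream carry two segments: IDs greater than zero decode into the usual *_pretokenized feature, negated IDs are flipped positive and decode into the new *_neg_pretokenized feature, and pad ID 0 is dropped by both filters. A toy illustration (the ID values are invented; a real vocabulary.Vocabulary would do the decoding):

ids = [17, -5, -9, 23, 4, 0]

neg_ids = [abs(i) for i in ids if i < 0]  # [5, 9]      -> *_neg_pretokenized
pos_ids = [i for i in ids if i > 0]       # [17, 23, 4] -> *_pretokenized

# With a real vocabulary, e.g.:
#   example["targets_pretokenized"] = vocabulary["targets"].decode(pos_ids)
#   example["targets_neg_pretokenized"] = vocabulary["targets"].decode(neg_ids)
print(neg_ids, pos_ids)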
@@ -1730,7 +1851,8 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
                        sequence_length, model_dir, eval_checkpoint_step,
                        inputs=gin.REQUIRED, targets=gin.REQUIRED,
                        score_postprocess_fn=gin.REQUIRED, eos_id=1,
-                       score_eos=True):
+                       score_eos=True,
+                       score_with_estimator_fn=score_with_estimator):
   """Compute log likelihoods per example and write to a text file.

   inputs & targets must either be the same length (in lines) or have inputs
@@ -1761,6 +1883,7 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
     score_eos: a boolean - whether to score the final eos token of each line.
       If this is set to false, the scores can be interpreted as prefix
      log-likelihoods
+    score_with_estimator_fn: a function to run scoring with the estimator.
   Returns:
     a list of floats
   """
@@ -1806,7 +1929,7 @@ def input_fn(params):
     dataset = dataset.batch(batch_size, drop_remainder=True)
     return dataset.prefetch(tf.data.experimental.AUTOTUNE)

-  return score_with_estimator(
+  return score_with_estimator_fn(
       estimator, input_fn, eval_checkpoint_step, model_dir,
       vocabulary, score_postprocess_fn, len(targets))

@@ -1815,7 +1938,8 @@ def input_fn(params):
 def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
                        model_dir, eval_checkpoint_step, dataset_split,
                        score_dataset_fn=None,
-                       score_postprocess_fn=gin.REQUIRED):
+                       score_postprocess_fn=gin.REQUIRED,
+                       score_with_estimator_fn=score_with_estimator):
   """Compute log likelihoods per example and write to a text file.

   The function returns a list of floats representing the log-likelihood of the
@@ -1837,6 +1961,7 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
       See `eval_dataset_fn` argument to `eval_model` for details.
     score_postprocess_fn: Function that takes in model outputs and
       post-processes then returns them.
+    score_with_estimator_fn: a function to run scoring with the estimator.

   Returns:
     scores: a list of floats, the log likelihood scores
@@ -1850,9 +1975,9 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
     input_fn = _get_combined_dataset_input_fn(
         scoring_datasets, batch_size, sequence_length)

-  return score_with_estimator(
+  return score_with_estimator_fn(
       estimator, input_fn, eval_checkpoint_step, model_dir,
-      vocabulary, score_postprocess_fn, None)
+      vocabulary, score_postprocess_fn)


 def get_estimator(model_type, vocabulary, mesh_shape,
@@ -2093,7 +2218,8 @@ def eval_model(estimator,
                eval_checkpoint_step,
                eval_with_score=False,
                output_eval_examples=True,
-               eval_dir_suffix=None):
+               eval_dir_suffix=None,
+               score_with_estimator_fn=score_with_estimator):
   """Eval a Mesh-TF model.

   Args:
@@ -2137,6 +2263,7 @@ def eval_model(estimator,
       of the eval examples in plaintext to eval_summary_dir.
     eval_dir_suffix: string, if not None, will be appended to the
       eval_summary_dir.
+    score_with_estimator_fn: a function to run scoring with the estimator.
   """
   if eval_dataset_fn is None:
     raise ValueError("Must provide eval_dataset_fn through gin for eval.")
@@ -2248,7 +2375,7 @@ def eval_model(estimator,
     tf.logging.info("Checkpoint path %s" % checkpoint_path)
     global_step = int(get_step_from_checkpoint_path(checkpoint_path))
     if eval_with_score:
-      outputs, _ = score_with_estimator(
+      outputs, _ = score_with_estimator_fn(
           estimator, input_fn, global_step, model_dir, vocabulary,
           num_examples=sum(len(cex) for cex in cached_examples.values()))
     else:
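Because the scoring entry points now accept score_with_estimator_fn as a parameter, the lazy sharded path can be selected through gin rather than code changes. A hypothetical binding sketch (it assumes score_from_dataset is registered with gin, as its gin.REQUIRED defaults suggest; the binding values are illustrative only):

import gin

# Route scoring through the lazy sharded writer; values are examples only.
gin.parse_config("""
score_from_dataset.score_with_estimator_fn = @score_with_estimator_lazy
score_with_estimator_lazy.num_examples_per_shard = 10000
save_scores_to_tfrecords.scores_filename = '/tmp/scores'
""")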
