26
26
27
27
import functools
28
28
import itertools
29
+ import math
29
30
import os
30
31
import random
31
32
import re
33
+ import time
32
34
33
35
import gin
34
36
import gin .tf
38
40
from mesh_tensorflow .transformer import learning_rate_schedules
39
41
from mesh_tensorflow .transformer import transformer
40
42
import numpy as np
43
+ import pandas as pd
41
44
import pkg_resources
42
45
import six
43
46
import tensorflow .compat .v1 as tf
@@ -1654,6 +1657,52 @@ def get_sequence_length(tokens, pad_id=0):
1654
1657
return scores
1655
1658
1656
1659
1660
@gin.configurable
def save_scores_to_tfrecords(
    results, vocabulary, scores_filename, shard_idx=0, save_ids_only=False):
  """Processes results from scoring examples and saves them to a TFRecord file.

  Every result is written as one serialized `tf.train.Example` holding the
  features "input" (bytes), "target" (bytes) and "score" (float). The output
  path is "<scores_filename>_<shard_idx>.tfrecord".

  The "input" feature is read from the "targets_neg_pretokenized" entry of each
  result (empty string when absent); the "target" feature is read from
  "targets_pretokenized", falling back to "targets".

  Args:
    results: list of dictionaries containing the results for each scored
      example.
    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
      targets_vocabulary) tuple; forwarded to
      `_maybe_add_pretokenized_features`.
    scores_filename: a string (path prefix of the file to write scores to).
    shard_idx: an integer indicating the current index of the file for
      sharding.
    save_ids_only: if true, keep only the part of each input that precedes the
      first space (the ID that is prepended to the inputs).
  """
  results = _maybe_add_pretokenized_features(results, vocabulary)

  def _to_bytes(text):
    # bytes(value, encoding) rejects values that are already bytes, so those
    # are passed through unchanged and only text is encoded.
    return text if isinstance(text, bytes) else bytes(text, "utf8")

  def _bytes_feature(text):
    return tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[_to_bytes(text)]))

  table_path = "{}_{}.tfrecord".format(scores_filename, shard_idx)
  tf.logging.info("Saving results to {}".format(table_path))

  with tf.io.TFRecordWriter(table_path) as file_writer:
    for result in results:
      score = result.get("scores", 0.0)

      # Only fall back to "targets" when the decoded text is really missing;
      # a dict.get default would be evaluated eagerly and raise KeyError for
      # results that carry "targets_pretokenized" but no "targets".
      if "targets_pretokenized" in result:
        target = result["targets_pretokenized"]
      else:
        target = result["targets"]

      input_ = result.get("targets_neg_pretokenized", "")
      if save_ids_only:
        input_ = input_.split(" ", 1)[0]

      example = tf.train.Example(
          features=tf.train.Features(
              feature={
                  "input": _bytes_feature(input_),
                  "target": _bytes_feature(target),
                  "score":
                      tf.train.Feature(
                          float_list=tf.train.FloatList(value=[score])),
              }))
      file_writer.write(example.SerializeToString())
1703
+
1704
+
1705
+ @gin .configurable
1657
1706
def score_with_estimator (estimator , input_fn , eval_checkpoint_step , model_dir ,
1658
1707
vocabulary , score_postprocess_fn = save_scores ,
1659
1708
num_examples = None ):
@@ -1691,6 +1740,70 @@ def score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
1691
1740
return score_postprocess_fn (results , vocabulary )
1692
1741
1693
1742
1743
@gin.configurable
def score_with_estimator_lazy(
    estimator, input_fn, eval_checkpoint_step, model_dir,
    vocabulary, score_postprocess_fn=save_scores_to_tfrecords,
    num_examples=None, num_examples_per_shard=10000):
  """Score each example returned by input_fn lazily.

  Predictions are consumed one at a time and handed to `score_postprocess_fn`
  in groups of at most `num_examples_per_shard`, so the full set of results is
  never held in memory at once.

  Args:
    estimator: a TPUEstimator
    input_fn: a function that returns a tf.data.Dataset with examples
      containing the string field 'targets' and optionally the field 'inputs'
    eval_checkpoint_step: int, list of ints, or None, see `eval_model`
      docstring. Must resolve to exactly one checkpoint.
    model_dir: string, estimator model_dir
    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
      targets_vocabulary) tuple
    score_postprocess_fn: a function called once per shard as
      `score_postprocess_fn(results, vocabulary, shard_idx=shard_idx)` to
      post-process and save the model outputs of that shard.
    num_examples: int, the total # of examples being scored, None if unknown.
      At most this many predictions are consumed.
    num_examples_per_shard: int, the number of examples per file shard.

  Returns:
    None. Results are passed to `score_postprocess_fn` shard by shard.
  """
  if num_examples is not None:
    num_shards = math.ceil(num_examples / num_examples_per_shard)
  else:
    num_shards = None
  tf.logging.info(
      "Scoring {} examples with {} shards at {} examples per shard".format(
          num_examples, num_shards, num_examples_per_shard))

  # Single-element unpacking: raises unless exactly one checkpoint is found.
  checkpoint_path, = get_checkpoint_iterator(
      eval_checkpoint_step, model_dir)
  result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)

  def _flush(shard_results, idx, started_at):
    """Post-processes one shard and logs its throughput."""
    score_postprocess_fn(shard_results, vocabulary, shard_idx=idx)
    elapsed = time.time() - started_at
    # Guard against a zero interval on coarse clocks.
    rate = len(shard_results) / elapsed if elapsed > 0 else float("inf")
    tf.logging.info(
        "Scored {} results in {} s, {} examples/s for shard {}".format(
            len(shard_results), elapsed, rate, idx))

  start = time.time()
  results = []
  shard_idx = 0

  # islice stops after exactly `num_examples` predictions (no limit when it is
  # None). Comparing a 0-based index with `num_examples` after appending let
  # one extra prediction through.
  for result in itertools.islice(result_iter, num_examples):
    results.append(result)
    if len(results) < num_examples_per_shard:
      continue

    _flush(results, shard_idx, start)
    results = []
    shard_idx += 1
    start = time.time()

  # Write whatever is left in the final, partially filled shard.
  if results:
    _flush(results, shard_idx, start)
1805
+
1806
+
1694
1807
def _maybe_add_pretokenized_features (examples , vocabulary ):
1695
1808
"""Ensures decoded versions of "inputs" and "targets" exist in each example.
1696
1809
@@ -1712,9 +1825,17 @@ def _maybe_add_pretokenized_features(examples, vocabulary):
1712
1825
for example in examples :
1713
1826
for feature_name in ["inputs" , "targets" ]:
1714
1827
pretokenized_feature_name = feature_name + "_pretokenized"
1828
+ neg_pretokenized_feature_name = feature_name + "_neg_pretokenized"
1715
1829
if feature_name in example and pretokenized_feature_name not in example :
1716
- s = vocabulary [feature_name ].decode (example [feature_name ].tolist ())
1830
+ ids = example [feature_name ].tolist ()
1831
+
1832
+ neg_ids = [abs (i ) for i in ids if i < 0 ]
1833
+ ids = [i for i in ids if i > 0 ]
1834
+
1835
+ s = vocabulary [feature_name ].decode (ids )
1717
1836
example [pretokenized_feature_name ] = s
1837
+ neg_s = vocabulary [feature_name ].decode (neg_ids )
1838
+ example [neg_pretokenized_feature_name ] = neg_s
1718
1839
1719
1840
if not added_pretokenized [feature_name ]:
1720
1841
added_pretokenized [feature_name ] = True
@@ -1730,7 +1851,8 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
1730
1851
sequence_length , model_dir , eval_checkpoint_step ,
1731
1852
inputs = gin .REQUIRED , targets = gin .REQUIRED ,
1732
1853
score_postprocess_fn = gin .REQUIRED , eos_id = 1 ,
1733
- score_eos = True ):
1854
+ score_eos = True ,
1855
+ score_with_estimator_fn = score_with_estimator ):
1734
1856
"""Compute log likelihoods per example and write to a text file.
1735
1857
1736
1858
inputs & targets must either be the same length (in lines) or have inputs
@@ -1761,6 +1883,7 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
1761
1883
score_eos: a boolean - whether to score the final eos token of each line
1762
1884
If this is set to false, the scores can be interpreted as prefix
1763
1885
log-likelihoods
1886
+ score_with_estimator_fn: a function to run scoring with the estimator.
1764
1887
Returns:
1765
1888
a list of floats
1766
1889
"""
@@ -1806,7 +1929,7 @@ def input_fn(params):
1806
1929
dataset = dataset .batch (batch_size , drop_remainder = True )
1807
1930
return dataset .prefetch (tf .data .experimental .AUTOTUNE )
1808
1931
1809
- return score_with_estimator (
1932
+ return score_with_estimator_fn (
1810
1933
estimator , input_fn , eval_checkpoint_step , model_dir ,
1811
1934
vocabulary , score_postprocess_fn , len (targets ))
1812
1935
@@ -1815,7 +1938,8 @@ def input_fn(params):
1815
1938
def score_from_dataset (estimator , vocabulary , batch_size , sequence_length ,
1816
1939
model_dir , eval_checkpoint_step , dataset_split ,
1817
1940
score_dataset_fn = None ,
1818
- score_postprocess_fn = gin .REQUIRED ):
1941
+ score_postprocess_fn = gin .REQUIRED ,
1942
+ score_with_estimator_fn = score_with_estimator ):
1819
1943
"""Compute log likelihoods per example and write to a text file.
1820
1944
1821
1945
The function returns a list of floats representing the log-likelihood of the
@@ -1837,6 +1961,7 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
1837
1961
See `eval_dataset_fn` argument to `eval_model` for details.
1838
1962
score_postprocess_fn: Function that takes in model outputs and
1839
1963
post-processes then returns then.
1964
+ score_with_estimator_fn: a function to run scoring with the estimator.
1840
1965
1841
1966
Returns:
1842
1967
scores: a list of floats, the log likelihood scores
@@ -1850,9 +1975,9 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
1850
1975
input_fn = _get_combined_dataset_input_fn (
1851
1976
scoring_datasets , batch_size , sequence_length )
1852
1977
1853
- return score_with_estimator (
1978
+ return score_with_estimator_fn (
1854
1979
estimator , input_fn , eval_checkpoint_step , model_dir ,
1855
- vocabulary , score_postprocess_fn , None )
1980
+ vocabulary , score_postprocess_fn )
1856
1981
1857
1982
1858
1983
def get_estimator (model_type , vocabulary , mesh_shape ,
@@ -2093,7 +2218,8 @@ def eval_model(estimator,
2093
2218
eval_checkpoint_step ,
2094
2219
eval_with_score = False ,
2095
2220
output_eval_examples = True ,
2096
- eval_dir_suffix = None ):
2221
+ eval_dir_suffix = None ,
2222
+ score_with_estimator_fn = score_with_estimator ):
2097
2223
"""Eval a Mesh-TF model.
2098
2224
2099
2225
Args:
@@ -2137,6 +2263,7 @@ def eval_model(estimator,
2137
2263
of the eval examples in plaintext to eval_summary_dir.
2138
2264
eval_dir_suffix: string, if not None then will appended to the
2139
2265
eval_summary_dir.
2266
+ score_with_estimator_fn: a function to run scoring with the estimator.
2140
2267
"""
2141
2268
if eval_dataset_fn is None :
2142
2269
raise ValueError ("Must provide eval_dataset_fn through gin for eval." )
@@ -2248,7 +2375,7 @@ def eval_model(estimator,
2248
2375
tf .logging .info ("Checkpoint path %s" % checkpoint_path )
2249
2376
global_step = int (get_step_from_checkpoint_path (checkpoint_path ))
2250
2377
if eval_with_score :
2251
- outputs , _ = score_with_estimator (
2378
+ outputs , _ = score_with_estimator_fn (
2252
2379
estimator , input_fn , global_step , model_dir , vocabulary ,
2253
2380
num_examples = sum (len (cex ) for cex in cached_examples .values ()))
2254
2381
else :
0 commit comments