@@ -20,6 +20,7 @@
 from __future__ import print_function
 
 import collections
+import contextlib
 import json
 import os
 import random
@@ -35,6 +36,7 @@
 
 import tensorflow as tf
 
+from tensorflow.contrib.tpu.python.tpu import tpu_estimator
 from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python import debug
 
@@ -285,7 +287,9 @@ def create_estimator(model_name,
                      decode_hparams=None,
                      use_tpu=False,
                      use_tpu_estimator=False,
-                     use_xla=False):
+                     use_xla=False,
+                     export_saved_model_api_version=1,
+                     use_guarantee_const_getter=False):
   """Create a T2T Estimator."""
   model_fn = t2t_model.T2TModel.make_estimator_model_fn(
       model_name, hparams, decode_hparams=decode_hparams, use_tpu=use_tpu)
@@ -307,14 +311,66 @@ def create_estimator(model_name,
     if decode_hparams and run_config.tpu_config:
       decode_hparams.add_hparam("iterations_per_loop",
                                 run_config.tpu_config.iterations_per_loop)
+    if export_saved_model_api_version == 1:
+      api_version_enum_name = tpu_estimator.ExportSavedModelApiVersion.V1
+      estimator_model_fn = model_fn
+    elif export_saved_model_api_version == 2:
+      api_version_enum_name = tpu_estimator.ExportSavedModelApiVersion.V2
+
+      def maybe_use_guarantee_const_getter_model_fn(features, labels, mode,
+                                                    params):
+        """Wrapper model_fn with guarantee_const getter."""
+        if not use_guarantee_const_getter:
+          return model_fn(features, labels, mode, params)
+
+        # This getter marks all weights as constant, which may improve TPU
+        # inference performance because it prevents the weights from being
+        # transferred to the TPU. It will increase HBM "program" usage and
+        # reduce HBM "arguments" usage during TPU model serving.
+        def guarantee_const_getter(getter, name, *args, **kwargs):
+          with tf.control_dependencies(None):
+            return tf.guarantee_const(
+                getter(name, *args, **kwargs), name=name + "/GuaranteeConst")
+
+        @contextlib.contextmanager
+        def guarantee_const_scope():
+          var_scope = tf.get_variable_scope()
+          prev_custom_getter = var_scope.custom_getter
+          prev_caching_device = var_scope.caching_device
+          var_scope.set_custom_getter(guarantee_const_getter)
+          var_scope.set_caching_device(lambda op: op.device)
+          yield
+          var_scope.set_custom_getter(prev_custom_getter)
+          var_scope.set_caching_device(prev_caching_device)
+
+        with guarantee_const_scope():
+          return model_fn(features, labels, mode, params)
+
+      def tpu_model_fn(features, labels, mode, params):
+        """Wrapper model_fn with tpu.rewrite / TPUPartitionedCall."""
+        if mode == tf.estimator.ModeKeys.PREDICT and params["use_tpu"]:
+          return tpu_estimator.model_fn_inference_on_tpu(
+              maybe_use_guarantee_const_getter_model_fn,
+              features=features,
+              labels=labels,
+              config=None,
+              params=params,
+              batch_config=None)
+        else:
+          return model_fn(features, labels, mode, params)
+
+      estimator_model_fn = tpu_model_fn
+    else:
+      raise ValueError("Flag export_saved_model_api_version must be 1 or 2.")
     estimator = tf.contrib.tpu.TPUEstimator(
-        model_fn=model_fn,
+        model_fn=estimator_model_fn,
         model_dir=run_config.model_dir,
         config=run_config,
         use_tpu=use_tpu,
         train_batch_size=batch_size,
         eval_batch_size=batch_size if "eval" in schedule else None,
-        predict_batch_size=predict_batch_size)
+        predict_batch_size=predict_batch_size,
+        export_saved_model_api_version=api_version_enum_name)
   else:
     estimator = tf.estimator.Estimator(
         model_fn=model_fn,
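
Note (illustration, not part of the patch): a minimal end-to-end sketch of how the new create_estimator arguments could be exercised. The model and problem names, directories, and the serving input receiver below are assumptions; a real T2T export would derive the receiver from the problem's own serving input function, and the create_hparams/create_run_config arguments shown are the usual trainer_lib helpers rather than anything introduced here.

import tensorflow as tf
from tensor2tensor.utils import trainer_lib

# Assumed hparams/run_config setup; names and paths are placeholders.
hparams = trainer_lib.create_hparams(
    "transformer_tiny",
    data_dir="/tmp/t2t_data",
    problem_name="translate_ende_wmt32k")
run_config = trainer_lib.create_run_config(
    model_name="transformer",
    model_dir="/tmp/t2t_train",
    use_tpu=True)

estimator = trainer_lib.create_estimator(
    "transformer",
    hparams,
    run_config,
    use_tpu=True,
    use_tpu_estimator=True,
    export_saved_model_api_version=2,   # selects ExportSavedModelApiVersion.V2
    use_guarantee_const_getter=True)    # only takes effect on the V2 path above

def serving_input_receiver_fn():
  # Placeholder receiver for illustration; the real one should match the
  # features the T2T model expects (normally built from the problem).
  serialized = tf.placeholder(tf.string, shape=[None], name="serialized_example")
  return tf.estimator.export.ServingInputReceiver(
      features={"inputs": serialized}, receiver_tensors=serialized)

estimator.export_saved_model("/tmp/t2t_export", serving_input_receiver_fn)

With export_saved_model_api_version=2, TPUEstimator leaves the TPU rewrite to the model_fn itself, which is what the tpu_model_fn wrapper above does through model_fn_inference_on_tpu; with the default value of 1 the behavior is unchanged.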
@@ -633,6 +689,8 @@ def create_experiment(
     use_tpu=False,
     use_tpu_estimator=False,
     use_xla=False,
+    export_saved_model_api_version=1,
+    use_guarantee_const_getter=False,
     additional_train_hooks=None,
     additional_eval_hooks=None,
     warm_start_from=None,
@@ -668,7 +726,9 @@ def create_experiment(
       decode_hparams=decode_hparams,
       use_tpu=use_tpu,
       use_tpu_estimator=use_tpu_estimator,
-      use_xla=use_xla)
+      use_xla=use_xla,
+      export_saved_model_api_version=export_saved_model_api_version,
+      use_guarantee_const_getter=use_guarantee_const_getter)
 
   # Input fns from Problem
   problem = hparams.problem
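
Note (illustration, not part of the patch): the same two flags are also threaded through create_experiment, so callers going through the experiment wrapper can opt in without calling create_estimator directly. A sketch reusing the assumed trainer_lib import, hparams, and run_config from the note above; the names and step counts are placeholders.

# Illustrative only: plumbs the new flags through the experiment wrapper.
exp_fn = trainer_lib.create_experiment_fn(
    model_name="transformer",
    problem_name="translate_ende_wmt32k",
    data_dir="/tmp/t2t_data",
    train_steps=1000,
    eval_steps=100,
    use_tpu=True,
    use_tpu_estimator=True,
    export_saved_model_api_version=2,
    use_guarantee_const_getter=True)
experiment = exp_fn(run_config, hparams)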