
Commit d38f343

ziy authored and copybara-github committed
[T2T] Fixed high usage of TPU HBM "Arguments" during serving
- Added flag export_saved_model_api_version (defaults to 1).
- Added maybe_use_guarantee_const_getter_model_fn and the use_guarantee_const_getter flag. It marks all weights as constant, which may improve TPU inference performance because it prevents the weights from being transferred to the TPU. It increases HBM "program" usage and reduces HBM "arguments" usage during TPU model serving.

PiperOrigin-RevId: 256026810
1 parent 1fb49ea commit d38f343

3 files changed (+78, -5 lines)


tensor2tensor/bin/t2t_trainer.py

Lines changed: 11 additions & 0 deletions
@@ -58,6 +58,15 @@
 flags.DEFINE_bool("use_tpu", False, "Whether to use TPU.")
 flags.DEFINE_bool("use_tpu_estimator", False, "Whether to use TPUEstimator. "
                   "This is always enabled when use_tpu is True.")
+flags.DEFINE_integer("export_saved_model_api_version", 1,
+                     "ExportSavedModelApiVersion, 1 (V1, default) or 2 (V2). "
+                     "V2 uses model_fn_inference_on_tpu for rewrite. "
+                     "Flag use_guarantee_const_getter is only enabled in V2.")
+flags.DEFINE_bool("use_guarantee_const_getter", False,
+                  "Whether to use GuaranteeConst Ops to mark all weights as "
+                  "constant. It may improve TPU inference performance and "
+                  "reduce HBM arguments usage. Only available when "
+                  "export_saved_model_api_version=2 and use_tpu=True.")
 flags.DEFINE_bool("xla_compile", False,
                   "Whether to use XLA to compile model_fn.")
 flags.DEFINE_integer("xla_jit_level", -1,

@@ -197,6 +206,8 @@ def create_experiment_fn():
       use_tpu=FLAGS.use_tpu,
       use_tpu_estimator=FLAGS.use_tpu_estimator,
       use_xla=FLAGS.xla_compile,
+      export_saved_model_api_version=FLAGS.export_saved_model_api_version,
+      use_guarantee_const_getter=FLAGS.use_guarantee_const_getter,
       warm_start_from=FLAGS.warm_start_from,
       decode_from_file=FLAGS.decode_from_file,
       decode_to_file=FLAGS.decode_to_file,
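
With these flags in place, the V2 serving path can be switched on from the command line. A minimal sketch of a t2t-trainer invocation (the usual flags such as --model, --problem, --data_dir, and --output_dir are elided and depend on your setup):

    t2t-trainer \
      --use_tpu=True \
      --export_saved_model_api_version=2 \
      --use_guarantee_const_getter=True \
      ...

Per the flag definitions above, use_guarantee_const_getter only takes effect when export_saved_model_api_version=2 and use_tpu=True.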

tensor2tensor/serving/export.py

Lines changed: 3 additions & 1 deletion
@@ -82,7 +82,9 @@ def create_estimator(run_config, hparams):
       hparams,
       run_config,
       decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
-      use_tpu=FLAGS.use_tpu)
+      use_tpu=FLAGS.use_tpu,
+      export_saved_model_api_version=FLAGS.export_saved_model_api_version,
+      use_guarantee_const_getter=FLAGS.use_guarantee_const_getter)


 def create_hparams():
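
Because export.py forwards FLAGS.export_saved_model_api_version and FLAGS.use_guarantee_const_getter straight into trainer_lib.create_estimator, the same flags should apply at export time. A hypothetical t2t-exporter invocation (assuming the exporter shares the trainer's flags, as the hunk above suggests; model and directory flags elided):

    t2t-exporter \
      --use_tpu=True \
      --export_saved_model_api_version=2 \
      --use_guarantee_const_getter=True \
      ...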

tensor2tensor/utils/trainer_lib.py

Lines changed: 64 additions & 4 deletions
@@ -20,6 +20,7 @@
 from __future__ import print_function

 import collections
+import contextlib
 import json
 import os
 import random

@@ -35,6 +36,7 @@

 import tensorflow as tf

+from tensorflow.contrib.tpu.python.tpu import tpu_estimator
 from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python import debug

@@ -285,7 +287,9 @@ def create_estimator(model_name,
                      decode_hparams=None,
                      use_tpu=False,
                      use_tpu_estimator=False,
-                     use_xla=False):
+                     use_xla=False,
+                     export_saved_model_api_version=1,
+                     use_guarantee_const_getter=False):
   """Create a T2T Estimator."""
   model_fn = t2t_model.T2TModel.make_estimator_model_fn(
       model_name, hparams, decode_hparams=decode_hparams, use_tpu=use_tpu)

@@ -307,14 +311,66 @@ def create_estimator(model_name,
     if decode_hparams and run_config.tpu_config:
       decode_hparams.add_hparam("iterations_per_loop",
                                 run_config.tpu_config.iterations_per_loop)
+    if export_saved_model_api_version == 1:
+      api_version_enum_name = tpu_estimator.ExportSavedModelApiVersion.V1
+      estimator_model_fn = model_fn
+    elif export_saved_model_api_version == 2:
+      api_version_enum_name = tpu_estimator.ExportSavedModelApiVersion.V2
+
+      def maybe_use_guarantee_const_getter_model_fn(features, labels, mode,
+                                                    params):
+        """Wrapper model_fn with guarantee_const getter."""
+        if not use_guarantee_const_getter:
+          return model_fn(features, labels, mode, params)
+
+        # It marks all weights as constant, which may improve TPU inference
+        # performance because it prevents the weights from being transferred
+        # to the TPU. It will increase HBM "program" usage and reduce HBM
+        # "arguments" usage during TPU model serving.
+        def guarantee_const_getter(getter, name, *args, **kwargs):
+          with tf.control_dependencies(None):
+            return tf.guarantee_const(
+                getter(name, *args, **kwargs), name=name + "/GuaranteeConst")
+
+        @contextlib.contextmanager
+        def guarantee_const_scope():
+          var_scope = tf.get_variable_scope()
+          prev_custom_getter = var_scope.custom_getter
+          prev_caching_device = var_scope.caching_device
+          var_scope.set_custom_getter(guarantee_const_getter)
+          var_scope.set_caching_device(lambda op: op.device)
+          yield
+          var_scope.set_custom_getter(prev_custom_getter)
+          var_scope.set_caching_device(prev_caching_device)
+
+        with guarantee_const_scope():
+          return model_fn(features, labels, mode, params)
+
+      def tpu_model_fn(features, labels, mode, params):
+        """Wrapper model_fn with tpu.rewrite / TPUPartitionedCall."""
+        if mode == tf.estimator.ModeKeys.PREDICT and params["use_tpu"]:
+          return tpu_estimator.model_fn_inference_on_tpu(
+              maybe_use_guarantee_const_getter_model_fn,
+              features=features,
+              labels=labels,
+              config=None,
+              params=params,
+              batch_config=None)
+        else:
+          return model_fn(features, labels, mode, params)
+
+      estimator_model_fn = tpu_model_fn
+    else:
+      raise ValueError("Flag export_saved_model_api_version must be 1 or 2.")
     estimator = tf.contrib.tpu.TPUEstimator(
-        model_fn=model_fn,
+        model_fn=estimator_model_fn,
         model_dir=run_config.model_dir,
         config=run_config,
         use_tpu=use_tpu,
         train_batch_size=batch_size,
         eval_batch_size=batch_size if "eval" in schedule else None,
-        predict_batch_size=predict_batch_size)
+        predict_batch_size=predict_batch_size,
+        export_saved_model_api_version=api_version_enum_name)
   else:
     estimator = tf.estimator.Estimator(
         model_fn=model_fn,

@@ -633,6 +689,8 @@ def create_experiment(
     use_tpu=False,
     use_tpu_estimator=False,
     use_xla=False,
+    export_saved_model_api_version=1,
+    use_guarantee_const_getter=False,
     additional_train_hooks=None,
     additional_eval_hooks=None,
     warm_start_from=None,

@@ -668,7 +726,9 @@ def create_experiment(
       decode_hparams=decode_hparams,
       use_tpu=use_tpu,
       use_tpu_estimator=use_tpu_estimator,
-      use_xla=use_xla)
+      use_xla=use_xla,
+      export_saved_model_api_version=export_saved_model_api_version,
+      use_guarantee_const_getter=use_guarantee_const_getter)

   # Input fns from Problem
   problem = hparams.problem
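
For readers outside the T2T codebase, the custom-getter pattern added above can be reproduced in isolation. Below is a minimal TF 1.x sketch (variable names are illustrative, not from this commit) that wraps every variable read in tf.guarantee_const, telling the compiler the value may be treated as a constant:

    import contextlib
    import tensorflow as tf  # TF 1.x API

    def guarantee_const_getter(getter, name, *args, **kwargs):
      # Clear surrounding control dependencies, then wrap the variable read
      # so XLA may fold the value as a compile-time constant.
      with tf.control_dependencies(None):
        return tf.guarantee_const(
            getter(name, *args, **kwargs), name=name + "/GuaranteeConst")

    @contextlib.contextmanager
    def guarantee_const_scope():
      # Temporarily install the getter on the current variable scope and
      # cache each variable on the device where it is consumed.
      var_scope = tf.get_variable_scope()
      prev_getter = var_scope.custom_getter
      prev_caching = var_scope.caching_device
      var_scope.set_custom_getter(guarantee_const_getter)
      var_scope.set_caching_device(lambda op: op.device)
      yield
      var_scope.set_custom_getter(prev_getter)
      var_scope.set_caching_device(prev_caching)

    with guarantee_const_scope():
      w = tf.get_variable("w", shape=[4, 4])  # read through the wrapped getter

The trade-off is the one described in the commit message: weights marked this way are baked into the compiled program image (HBM "program" usage grows) instead of being passed in as runtime arguments (HBM "arguments" usage shrinks).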
