Skip to content

Commit aee49bb

Browse files
authored
Merged commit includes the following changes: (tensorflow#7357)
261202754 by hongkuny<hongkuny@google.com>: Use enable_xla flag for classifier and squad, so xla option is exposed to users. -- PiperOrigin-RevId: 261202754
1 parent 8754280 commit aee49bb

File tree

5 files changed

+20
-24
lines changed

5 files changed

+20
-24
lines changed

official/bert/benchmark/bert_benchmark.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
 from official.bert import run_classifier
 from official.bert.benchmark import benchmark_utils
 from official.utils.misc import distribution_utils
-from official.utils.misc import keras_utils

 # pylint: disable=line-too-long
 PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1024_A-16/bert_model.ckpt'
@@ -55,7 +54,7 @@ def __init__(self, output_dir=None):
     self.num_steps_per_epoch = None

   @flagsaver.flagsaver
-  def _run_bert_classifier(self, callbacks=None, use_ds=True, enable_xla=False):
+  def _run_bert_classifier(self, callbacks=None, use_ds=True):
     """Starts BERT classification task."""
     with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
       input_meta_data = json.loads(reader.read().decode('utf-8'))
@@ -73,8 +72,6 @@ def _run_bert_classifier(self, callbacks=None, use_ds=True):
     strategy = distribution_utils.get_distribution_strategy(
         distribution_strategy='mirrored' if use_ds else 'off',
         num_gpus=self.num_gpus)
-    # TODO(hongkuny): Enable XLA once we are confident with its performance.
-    keras_utils.set_config_v2(enable_xla)

     steps_per_loop = 1

@@ -119,13 +116,10 @@ def _run_and_report_benchmark(self,
119116
training_summary_path,
120117
min_accuracy=0,
121118
max_accuracy=1,
122-
use_ds=True,
123-
enable_xla=False):
119+
use_ds=True):
124120
"""Starts BERT performance benchmark test."""
125-
126121
start_time_sec = time.time()
127-
self._run_bert_classifier(
128-
callbacks=[self.timer_callback], use_ds=use_ds, enable_xla=enable_xla)
122+
self._run_bert_classifier(callbacks=[self.timer_callback], use_ds=use_ds)
129123
wall_time_sec = time.time() - start_time_sec
130124

131125
with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
@@ -168,9 +162,10 @@ def benchmark_1_gpu_mrpc_xla(self):
     FLAGS.bert_config_file = self.bert_config_file
     FLAGS.train_batch_size = 4
     FLAGS.eval_batch_size = 4
+    FLAGS.enable_xla = True

     summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
-    self._run_and_report_benchmark(summary_path, enable_xla=True)
+    self._run_and_report_benchmark(summary_path)

   def benchmark_1_gpu_mrpc_no_dist_strat(self):
     """Test BERT model performance with 1 GPU, no distribution strategy."""
@@ -253,13 +248,11 @@ def __init__(self, output_dir=None, **kwargs):
   def _run_and_report_benchmark(self,
                                 training_summary_path,
                                 min_accuracy=0.84,
-                                max_accuracy=0.88,
-                                enable_xla=False):
+                                max_accuracy=0.88):
     """Starts BERT accuracy benchmark test."""

     start_time_sec = time.time()
-    self._run_bert_classifier(
-        callbacks=[self.timer_callback], enable_xla=enable_xla)
+    self._run_bert_classifier(callbacks=[self.timer_callback])
     wall_time_sec = time.time() - start_time_sec

     with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
@@ -296,9 +289,9 @@ def benchmark_8_gpu_mrpc_xla(self):
     """Run BERT model accuracy test with 8 GPUs with XLA."""
     self._setup()
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc_xla')
-
+    FLAGS.enable_xla = True
     summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
-    self._run_and_report_benchmark(summary_path, enable_xla=True)
+    self._run_and_report_benchmark(summary_path)


 if __name__ == '__main__':

official/bert/benchmark/bert_squad_benchmark.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
 from official.bert.benchmark import benchmark_utils
 from official.bert.benchmark import squad_evaluate_v1_1
 from official.utils.misc import distribution_utils
-from official.utils.misc import keras_utils

 # pylint: disable=line-too-long
 PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1024_A-16/bert_model.ckpt'
@@ -131,10 +130,8 @@ def _setup(self):

   def _run_and_report_benchmark(self,
                                 use_ds=True,
-                                enable_xla=False,
                                 run_eagerly=False):
     """Runs the benchmark and reports various metrics."""
-    keras_utils.set_config_v2(enable_xla)
     start_time_sec = time.time()
     self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
     wall_time_sec = time.time() - start_time_sec
@@ -164,8 +161,9 @@ def benchmark_1_gpu_xla(self):
     self.num_gpus = 1
     FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla_squad')
     FLAGS.train_batch_size = 4
+    FLAGS.enable_xla = True

-    self._run_and_report_benchmark(enable_xla=True)
+    self._run_and_report_benchmark()

   def benchmark_1_gpu_no_dist_strat(self):
     """Tests BERT SQuAD model performance with 1 GPU without DS."""
@@ -291,10 +289,8 @@ def _setup(self):

   def _run_and_report_benchmark(self,
                                 use_ds=True,
-                                enable_xla=False,
                                 run_eagerly=False):
     """Runs the benchmark and reports various metrics."""
-    keras_utils.set_config_v2(enable_xla)
     start_time_sec = time.time()
     self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
     self._evaluate_squad()
@@ -348,8 +344,9 @@ def benchmark_8_gpu_xla(self):
     self.num_gpus = 8
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_xla')
     FLAGS.train_batch_size = 32
+    FLAGS.enable_xla = True

-    self._run_and_report_benchmark(enable_xla=True)
+    self._run_and_report_benchmark()


 if __name__ == '__main__':

official/bert/common_flags.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def define_common_bert_flags():
       loss_scale=True,
       all_reduce_alg=False,
       num_packs=False,
-      enable_xla=False
+      enable_xla=True
   )

official/bert/run_classifier.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
 from official.bert import modeling
 from official.bert import optimization
 from official.bert import tpu_lib
+from official.utils.misc import keras_utils

 flags.DEFINE_enum(
     'mode', 'train_and_eval', ['train_and_eval', 'export_only'],
@@ -174,6 +175,8 @@ def run_bert(strategy, input_meta_data):

   if FLAGS.mode != 'train_and_eval':
     raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
+  # Enables XLA in Session Config. Should not be set for TPU.
+  keras_utils.set_config_v2(FLAGS.enable_xla)

   bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
   epochs = FLAGS.num_train_epochs

official/bert/run_squad.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
 from official.bert import squad_lib
 from official.bert import tokenization
 from official.bert import tpu_lib
+from official.utils.misc import keras_utils

 flags.DEFINE_bool('do_train', False, 'Whether to run training.')
 flags.DEFINE_bool('do_predict', False, 'Whether to run eval on the dev set.')
@@ -181,6 +182,8 @@ def train_squad(strategy,
   if strategy:
     logging.info('Training using customized training loop with distribution'
                  ' strategy.')
+  # Enables XLA in Session Config. Should not be set for TPU.
+  keras_utils.set_config_v2(FLAGS.enable_xla)

   use_float16 = common_flags.use_float16()
   if use_float16:

0 commit comments

Comments
 (0)