Skip to content

Commit 50dfb31

Browse files
authored
Add experimental tf.data sleep tuning for better performance (tensorflow#6634)
* Introduce a short sleep before ds.prefetch in tf.data.
* Further limit dataset threads to reduce CPU contention.
* Tuned dataset sleep time.
* Rename dataset sleep flag; enable it only for Keras Graph mode.
1 parent 0d76b69 commit 50dfb31

File tree

3 files changed

+85
-4
lines changed

3 files changed

+85
-4
lines changed

official/resnet/keras/keras_common.py

Lines changed: 43 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -150,10 +150,11 @@ def set_gpu_thread_mode_and_count(flags_obj):
150150
# Limit data preprocessing threadpool to CPU cores minus number of total GPU
151151
# private threads and memory copy threads.
152152
total_gpu_thread_count = per_gpu_thread_count * flags_obj.num_gpus
153-
num_mem_copy_threads = flags_obj.num_gpus
153+
num_runtime_threads = flags_obj.num_gpus
154154
if not flags_obj.datasets_num_private_threads:
155-
flags_obj.datasets_num_private_threads = (cpu_count - total_gpu_thread_count
156-
- num_mem_copy_threads)
155+
flags_obj.datasets_num_private_threads = min(
156+
cpu_count - total_gpu_thread_count - num_runtime_threads,
157+
flags_obj.num_gpus * 8)
157158
tf.compat.v1.logging.info('Set datasets_num_private_threads to %s',
158159
flags_obj.datasets_num_private_threads)
159160

@@ -283,6 +284,13 @@ def define_keras_flags():
283284
'triggers the profiler to process 3 steps, starting from the 2nd step. '
284285
'Note that profiler has a non-trivial performance overhead, and the '
285286
'output file can be gigantic if profiling many steps.')
287+
flags.DEFINE_boolean(
288+
name='data_prefetch_with_slack', default=False,
289+
help='Add a small delay in tf.data prefetch to prioritize memory copy of '
290+
'other tensors over the data minibatch for the (T+1)th step. It should '
291+
'help improve performance using EagerIterator and function. The codepath '
292+
'when enabling this feature is experimental and will be removed once the '
293+
'corresponding performance features are fully supported in TensorFlow.')
286294

287295

288296
def get_synth_input_fn(height, width, num_channels, num_classes,
@@ -341,6 +349,12 @@ def is_v2_0():
341349
return tf.__version__.startswith('2')
342350

343351

352+
def data_prefetch_with_slack():
353+
"""Use unstable code for perf tuning purposes."""
354+
if not FLAGS.use_synthetic_data:
355+
_monkey_patch_org_create_device_dataset()
356+
357+
344358
def _monkey_patch_org_assert_broadcastable():
345359
"""Monkey-patch `assert_broadcast` op to avoid OOM when enabling XLA."""
346360
def no_op_assert_broadcastable(weights, values):
@@ -362,3 +376,29 @@ def _undo_monkey_patch_org_assert_broadcastable():
362376
if hasattr(weights_broadcast_ops, 'org_assert_broadcastable'):
363377
weights_broadcast_ops.assert_broadcastable = (
364378
weights_broadcast_ops.org_assert_broadcastable)
379+
380+
381+
# TODO(haoyuzhang): remove this monkey patch when the "prefetch with slack"
382+
# feature is available in tf.data.
383+
def _monkey_patch_org_create_device_dataset():
384+
"""Monkey-patch `_create_device_dataset` method with delayed prefetch."""
385+
386+
import ast # pylint: disable=g-import-not-at-top
387+
import inspect # pylint: disable=g-import-not-at-top
388+
from tensorflow.python.data.ops import multi_device_iterator_ops # pylint: disable=g-import-not-at-top
389+
390+
tf.compat.v1.logging.info(
391+
'Using monkey-patched version of MultiDeviceIterator. It should be '
392+
'removed when the prefetch with slack feature is implemented in tf.data.')
393+
cls_multi_device_iterator = ast.parse(
394+
inspect.getsource(multi_device_iterator_ops.MultiDeviceIterator))
395+
org_create_device_dataset_code = inspect.getsource(
396+
multi_device_iterator_ops.MultiDeviceIterator._create_device_dataset) # pylint: disable=protected-access
397+
code_lines = org_create_device_dataset_code.split('\n')
398+
# Insert in reverse order to avoid line number shift by previous insertions
399+
code_lines.insert(5, ' ds = ds.apply(sleep_ops.sleep(11000))') # 11ms
400+
code_lines.insert(2, ' from tensorflow.python.data.experimental.ops import sleep as sleep_ops') # pylint: disable=line-too-long
401+
patched_code = '\n'.join(line[2:] for line in code_lines)
402+
cls_multi_device_iterator.body[0].body[2] = ast.parse(patched_code).body[0]
403+
exec(compile(cls_multi_device_iterator, '<string>', 'exec'), # pylint: disable=exec-used
404+
multi_device_iterator_ops.__dict__)

official/resnet/keras/keras_imagenet_benchmark.py

Lines changed: 40 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -251,6 +251,21 @@ def benchmark_xla_1_gpu_fp16(self):
251251
FLAGS.batch_size = 256
252252
self._run_and_report_benchmark()
253253

254+
def benchmark_xla_1_gpu_fp16_tweaked(self):
255+
"""Test Keras model with XLA, 1 GPU, fp16, and manual config tuning."""
256+
self._setup()
257+
258+
FLAGS.num_gpus = 1
259+
FLAGS.enable_eager = True
260+
FLAGS.enable_xla = True
261+
FLAGS.distribution_strategy = 'default'
262+
FLAGS.model_dir = self._get_model_dir('benchmark_xla_1_gpu_fp16_tweaked')
263+
FLAGS.dtype = 'fp16'
264+
FLAGS.batch_size = 256
265+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
266+
FLAGS.data_prefetch_with_slack = True
267+
self._run_and_report_benchmark()
268+
254269
def benchmark_xla_1_gpu_fp16_dynamic(self):
255270
"""Test Keras model with XLA, 1 GPU, fp16, and dynamic loss scaling."""
256271
self._setup()
@@ -313,6 +328,23 @@ def benchmark_graph_xla_1_gpu_fp16(self):
313328
FLAGS.batch_size = 256
314329
self._run_and_report_benchmark()
315330

331+
def benchmark_graph_xla_1_gpu_fp16_tweaked(self):
332+
"""Test Keras model in legacy graph mode with 1 GPU, fp16, XLA, and manual
333+
config tuning.
334+
"""
335+
self._setup()
336+
337+
FLAGS.num_gpus = 1
338+
FLAGS.enable_eager = False
339+
FLAGS.enable_xla = True
340+
FLAGS.distribution_strategy = 'default'
341+
FLAGS.model_dir = self._get_model_dir(
342+
'benchmark_graph_xla_1_gpu_fp16_tweaked')
343+
FLAGS.dtype = 'fp16'
344+
FLAGS.batch_size = 256
345+
FLAGS.tf_gpu_thread_mode = 'gpu_private'
346+
self._run_and_report_benchmark()
347+
316348
def benchmark_8_gpu(self):
317349
"""Test Keras model with 8 GPUs."""
318350
self._setup()
@@ -334,6 +366,7 @@ def benchmark_8_gpu_tweaked(self):
334366
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_tweaked')
335367
FLAGS.batch_size = 128 * 8 # 8 GPUs
336368
FLAGS.datasets_num_private_threads = 14
369+
FLAGS.data_prefetch_with_slack = True
337370
self._run_and_report_benchmark()
338371

339372
def benchmark_xla_8_gpu(self):
@@ -371,6 +404,7 @@ def benchmark_8_gpu_fp16_tweaked(self):
371404
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_tweaked')
372405
FLAGS.batch_size = 256 * 8 # 8 GPUs
373406
FLAGS.tf_gpu_thread_mode = 'gpu_private'
407+
FLAGS.data_prefetch_with_slack = True
374408
self._run_and_report_benchmark()
375409

376410
def benchmark_8_gpu_fp16_dynamic_tweaked(self):
@@ -386,6 +420,7 @@ def benchmark_8_gpu_fp16_dynamic_tweaked(self):
386420
FLAGS.batch_size = 256 * 8 # 8 GPUs
387421
FLAGS.loss_scale = 'dynamic'
388422
FLAGS.tf_gpu_thread_mode = 'gpu_private'
423+
FLAGS.data_prefetch_with_slack = True
389424
self._run_and_report_benchmark()
390425

391426
def benchmark_xla_8_gpu_fp16(self):
@@ -412,7 +447,8 @@ def benchmark_xla_8_gpu_fp16_tweaked(self):
412447
FLAGS.distribution_strategy = 'default'
413448
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16_tweaked')
414449
FLAGS.batch_size = 256 * 8 # 8 GPUs
415-
FLAGS.tf_gpu_thread_mode = 'gpu_private'
450+
# FLAGS.tf_gpu_thread_mode = 'gpu_private'
451+
FLAGS.data_prefetch_with_slack = True
416452
self._run_and_report_benchmark()
417453

418454
def benchmark_xla_8_gpu_fp16_dynamic_tweaked(self):
@@ -429,6 +465,7 @@ def benchmark_xla_8_gpu_fp16_dynamic_tweaked(self):
429465
FLAGS.batch_size = 256 * 8 # 8 GPUs
430466
FLAGS.loss_scale = 'dynamic'
431467
FLAGS.tf_gpu_thread_mode = 'gpu_private'
468+
FLAGS.data_prefetch_with_slack = True
432469
self._run_and_report_benchmark()
433470

434471
def benchmark_xla_8_gpu_fp16_tensorboard_tweaked(self):
@@ -444,6 +481,7 @@ def benchmark_xla_8_gpu_fp16_tensorboard_tweaked(self):
444481
'benchmark_xla_8_gpu_fp16_tensorboard_tweaked')
445482
FLAGS.batch_size = 256 * 8 # 8 GPUs
446483
FLAGS.tf_gpu_thread_mode = 'gpu_private'
484+
FLAGS.data_prefetch_with_slack = True
447485
FLAGS.enable_tensorboard = True
448486
self._run_and_report_benchmark()
449487

@@ -636,6 +674,7 @@ def benchmark_8_gpu_tweaked(self):
636674
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_tweaked')
637675
FLAGS.batch_size = 256 * 8
638676
FLAGS.tf_gpu_thread_mode = 'gpu_private'
677+
FLAGS.data_prefetch_with_slack = True
639678
self._run_and_report_benchmark()
640679

641680
def benchmark_graph_8_gpu(self):

official/resnet/keras/keras_imagenet_main.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,8 @@ def run(flags_obj):
107107
# Execute flag override logic for better model performance
108108
if flags_obj.tf_gpu_thread_mode:
109109
keras_common.set_gpu_thread_mode_and_count(flags_obj)
110+
if flags_obj.data_prefetch_with_slack:
111+
keras_common.data_prefetch_with_slack()
110112

111113
dtype = flags_core.get_tf_dtype(flags_obj)
112114
if dtype == 'float16':

0 commit comments

Comments (0)