
Commit acc6f6d

Combined imagenet and cifar-10 estimator tests (tensorflow#6672)

* Combined imagenet and cifar-10 benchmarks
* Comments and epochs_between_evals.
* Added tuned tests and cleaned up benchmark flags
* Fix names.
* Return results and add images/sec hook.
* updated doc strings for return values.
* 128 to 256 batch for FP16 test
* added more doc strings to fix lint.

1 parent 67c403f commit acc6f6d
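
For context, a minimal sketch (not part of this commit) of how a PerfZero-style harness might drive one of the new benchmark classes; the module path `estimator_benchmark` and the directory arguments are hypothetical:

    # Hypothetical driver; module path and directories are illustrative only.
    from official.resnet import estimator_benchmark

    benchmark = estimator_benchmark.Resnet56EstimatorAccuracy(
        output_dir='/tmp/benchmark_logs', root_data_dir='/data')
    benchmark.unit_test()  # quick ResNet-8 CIFAR-10 run exercising the full path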

4 files changed: +398 −166 lines changed

Lines changed: 386 additions & 0 deletions
@@ -0,0 +1,386 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes Estimator benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

from absl import flags
from absl.testing import flagsaver
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.resnet import cifar10_main as cifar_main
from official.resnet import imagenet_main
from official.utils.logs import hooks

IMAGENET_DATA_DIR_NAME = 'imagenet'
CIFAR_DATA_DIR_NAME = 'cifar-10-batches-bin'
FLAGS = flags.FLAGS


class EstimatorBenchmark(tf.test.Benchmark):
  """Base class to hold methods common to test classes in the module.

  Code under test for Estimator models (ResNet50 and 56) reports mostly the
  same data and requires the same FLAG setup.
  """

  local_flags = None

  def __init__(self, output_dir=None, default_flags=None, flag_methods=None):
    if not output_dir:
      output_dir = '/tmp'
    self.output_dir = output_dir
    self.default_flags = default_flags or {}
    self.flag_methods = flag_methods or {}

  def _get_model_dir(self, folder_name):
    """Returns the directory to store info, e.g. saved model and event log."""
    return os.path.join(self.output_dir, folder_name)
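
  # The first _setup() call for a class defines the flags and snapshots their
  # values with flagsaver; subsequent calls restore that snapshot so every
  # test in the class starts from identical flag state.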
  def _setup(self):
    """Sets up and resets flags before each test."""
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
    if EstimatorBenchmark.local_flags is None:
      for flag_method in self.flag_methods:
        flag_method()
      # Loads flags to get defaults to then override. List cannot be empty.
      flags.FLAGS(['foo'])
      # Overrides flag values with defaults for the class of tests.
      for k, v in self.default_flags.items():
        setattr(FLAGS, k, v)
      saved_flag_values = flagsaver.save_flag_values()
      EstimatorBenchmark.local_flags = saved_flag_values
    else:
      flagsaver.restore_flag_values(EstimatorBenchmark.local_flags)

  def _report_benchmark(self,
                        stats,
                        wall_time_sec,
                        top_1_max=None,
                        top_1_min=None):
    """Reports benchmark results by writing to a local protobuf file.

    Args:
      stats: dict returned from estimator models with known entries.
      wall_time_sec: the duration of the benchmark execution in seconds.
      top_1_max: highest passing level for top_1 accuracy.
      top_1_min: lowest passing level for top_1 accuracy.
    """

    examples_per_sec_hook = None
    for hook in stats['train_hooks']:
      if isinstance(hook, hooks.ExamplesPerSecondHook):
        examples_per_sec_hook = hook
        break

    eval_results = stats['eval_results']
    metrics = []
    if 'accuracy' in eval_results:
      metrics.append({'name': 'accuracy_top_1',
                      'value': eval_results['accuracy'].item(),
                      'min_value': top_1_min,
                      'max_value': top_1_max})
    if 'accuracy_top_5' in eval_results:
      metrics.append({'name': 'accuracy_top_5',
                      'value': eval_results['accuracy_top_5'].item()})

    if examples_per_sec_hook:
      exp_per_second_list = examples_per_sec_hook.current_examples_per_sec_list
      # ExamplesPerSecondHook skips the first 10 steps.
      exp_per_sec = sum(exp_per_second_list) / len(exp_per_second_list)
      metrics.append({'name': 'exp_per_second',
                      'value': exp_per_sec})
    self.report_benchmark(
        iters=eval_results['global_step'],
        wall_time=wall_time_sec,
        metrics=metrics)
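
  # For reference (shape inferred from the code above), the `stats` dict that
  # the Estimator models return looks roughly like:
  #   {'eval_results': {'accuracy': <np scalar>, 'accuracy_top_5': <np scalar>,
  #                     'global_step': <int>, ...},
  #    'train_hooks': [ExamplesPerSecondHook, ...]}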


class Resnet50EstimatorAccuracy(EstimatorBenchmark):
  """Benchmark accuracy tests for ResNet50 w/ Estimator."""

  def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
    """Benchmark accuracy tests for ResNet50 w/ Estimator.

    Args:
      output_dir: directory where outputs, e.g. log files, are written.
      root_data_dir: directory under which to look for the dataset.
      **kwargs: arbitrary named arguments. This is needed to make the
        constructor forward compatible in case PerfZero provides more
        named arguments before updating the constructor.
    """
    flag_methods = [imagenet_main.define_imagenet_flags]

    self.data_dir = os.path.join(root_data_dir, IMAGENET_DATA_DIR_NAME)
    super(Resnet50EstimatorAccuracy, self).__init__(
        output_dir=output_dir, flag_methods=flag_methods)

  def benchmark_graph_8_gpu(self):
    """Tests 8 GPUs in graph mode."""
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.data_dir = self.data_dir
    FLAGS.batch_size = 128 * 8
    FLAGS.train_epochs = 90
    FLAGS.epochs_between_evals = 10
    FLAGS.model_dir = self._get_model_dir('benchmark_graph_8_gpu')
    FLAGS.dtype = 'fp32'
    FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def benchmark_graph_fp16_8_gpu(self):
    """Tests FP16 with 8 GPUs in graph mode."""
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.data_dir = self.data_dir
    FLAGS.batch_size = 256 * 8
    FLAGS.train_epochs = 90
    FLAGS.epochs_between_evals = 10
    FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_8_gpu')
    FLAGS.dtype = 'fp16'
    FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def _run_and_report_benchmark(self):
    start_time_sec = time.time()
    stats = imagenet_main.run_imagenet(flags.FLAGS)
    wall_time_sec = time.time() - start_time_sec
    self._report_benchmark(stats,
                           wall_time_sec,
                           top_1_min=0.762,
                           top_1_max=0.766)


class Resnet50EstimatorBenchmark(EstimatorBenchmark):
  """Benchmarks for ResNet50 using Estimator."""
  local_flags = None

  def __init__(self, output_dir=None, default_flags=None):
    flag_methods = [imagenet_main.define_imagenet_flags]

    super(Resnet50EstimatorBenchmark, self).__init__(
        output_dir=output_dir,
        default_flags=default_flags,
        flag_methods=flag_methods)

  def benchmark_graph_fp16_1_gpu(self):
    """Benchmarks graph fp16 1 gpu."""
    self._setup()

    FLAGS.num_gpus = 1
    FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_1_gpu')
    FLAGS.batch_size = 128
    FLAGS.dtype = 'fp16'
    FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def benchmark_graph_fp16_1_gpu_tweaked(self):
    """Benchmarks graph fp16 1 gpu tweaked."""
    self._setup()

    FLAGS.num_gpus = 1
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    FLAGS.intra_op_parallelism_threads = 1
    FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_1_gpu_tweaked')
    FLAGS.batch_size = 256
    FLAGS.dtype = 'fp16'
    FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def benchmark_graph_1_gpu(self):
    """Benchmarks graph 1 gpu."""
    self._setup()

    FLAGS.num_gpus = 1
    FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu')
    FLAGS.batch_size = 128
    FLAGS.dtype = 'fp32'
    FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def benchmark_graph_8_gpu(self):
    """Benchmarks graph 8 gpus."""
    self._setup()

    FLAGS.num_gpus = 8
    FLAGS.model_dir = self._get_model_dir('benchmark_graph_8_gpu')
    FLAGS.batch_size = 128 * 8
    FLAGS.dtype = 'fp32'
    FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def benchmark_graph_fp16_8_gpu(self):
    """Benchmarks graph fp16 8 gpus."""
    self._setup()

    FLAGS.num_gpus = 8
    FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_8_gpu')
    FLAGS.batch_size = 256 * 8
    FLAGS.dtype = 'fp16'
    FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def benchmark_graph_fp16_8_gpu_tweaked(self):
    """Benchmarks graph fp16 8 gpus tweaked."""
    self._setup()

    FLAGS.num_gpus = 8
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    FLAGS.intra_op_parallelism_threads = 1
    FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_8_gpu_tweaked')
    FLAGS.batch_size = 256 * 8
    FLAGS.dtype = 'fp16'
    FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def _run_and_report_benchmark(self):
    start_time_sec = time.time()
    stats = imagenet_main.run_imagenet(FLAGS)
    wall_time_sec = time.time() - start_time_sec
    print(stats)
    # Remove values to skip triggering accuracy check.
    del stats['eval_results']['accuracy']
    del stats['eval_results']['accuracy_top_5']
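    # In the Synth/Real configurations below these runs train for at most 110
    # steps, so the accuracy values are meaningless; deleting the keys keeps
    # _report_benchmark from emitting accuracy metrics at all.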

    self._report_benchmark(stats,
                           wall_time_sec)


class Resnet50EstimatorBenchmarkSynth(Resnet50EstimatorBenchmark):
  """Resnet50 synthetic benchmark tests."""

  def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
    def_flags = {}
    def_flags['use_synthetic_data'] = True
    def_flags['max_train_steps'] = 110
    def_flags['train_epochs'] = 1

    super(Resnet50EstimatorBenchmarkSynth, self).__init__(
        output_dir=output_dir, default_flags=def_flags)


class Resnet50EstimatorBenchmarkReal(Resnet50EstimatorBenchmark):
  """Resnet50 real data benchmark tests."""

  def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
    def_flags = {}
    def_flags['data_dir'] = os.path.join(root_data_dir, IMAGENET_DATA_DIR_NAME)
    def_flags['max_train_steps'] = 110
    def_flags['train_epochs'] = 1

    super(Resnet50EstimatorBenchmarkReal, self).__init__(
        output_dir=output_dir, default_flags=def_flags)


class Resnet56EstimatorAccuracy(EstimatorBenchmark):
  """Accuracy tests for Estimator ResNet56."""

  local_flags = None

  def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
    """Accuracy tests for ResNet56 w/ Estimator.

    Args:
      output_dir: directory where outputs, e.g. log files, are written.
      root_data_dir: directory under which to look for the dataset.
      **kwargs: arbitrary named arguments. This is needed to make the
        constructor forward compatible in case PerfZero provides more
        named arguments before updating the constructor.
    """
    flag_methods = [cifar_main.define_cifar_flags]

    self.data_dir = os.path.join(root_data_dir, CIFAR_DATA_DIR_NAME)
    super(Resnet56EstimatorAccuracy, self).__init__(
        output_dir=output_dir, flag_methods=flag_methods)

  def benchmark_graph_1_gpu(self):
    """Tests layers model with Estimator and distribution strategies."""
    self._setup()
    flags.FLAGS.num_gpus = 1
    flags.FLAGS.data_dir = self.data_dir
    flags.FLAGS.batch_size = 128
    flags.FLAGS.train_epochs = 182
    flags.FLAGS.model_dir = self._get_model_dir('benchmark_graph_1_gpu')
    flags.FLAGS.resnet_size = 56
    flags.FLAGS.dtype = 'fp32'
    flags.FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def benchmark_graph_fp16_1_gpu(self):
    """Tests layers FP16 model with Estimator and distribution strategies."""
    self._setup()
    flags.FLAGS.num_gpus = 1
    flags.FLAGS.data_dir = self.data_dir
    flags.FLAGS.batch_size = 128
    flags.FLAGS.train_epochs = 182
    flags.FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_1_gpu')
    flags.FLAGS.resnet_size = 56
    flags.FLAGS.dtype = 'fp16'
    flags.FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def benchmark_graph_2_gpu(self):
    """Tests layers model with Estimator and dist_strat. 2 GPUs."""
    self._setup()
    flags.FLAGS.num_gpus = 2
    flags.FLAGS.data_dir = self.data_dir
    flags.FLAGS.batch_size = 128
    flags.FLAGS.train_epochs = 182
    flags.FLAGS.model_dir = self._get_model_dir('benchmark_graph_2_gpu')
    flags.FLAGS.resnet_size = 56
    flags.FLAGS.dtype = 'fp32'
    flags.FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def benchmark_graph_fp16_2_gpu(self):
    """Tests layers FP16 model with Estimator and dist_strat. 2 GPUs."""
    self._setup()
    flags.FLAGS.num_gpus = 2
    flags.FLAGS.data_dir = self.data_dir
    flags.FLAGS.batch_size = 128
    flags.FLAGS.train_epochs = 182
    flags.FLAGS.model_dir = self._get_model_dir('benchmark_graph_fp16_2_gpu')
    flags.FLAGS.resnet_size = 56
    flags.FLAGS.dtype = 'fp16'
    flags.FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def unit_test(self):
    """A lightweight test that can finish quickly."""
    self._setup()
    flags.FLAGS.num_gpus = 1
    flags.FLAGS.data_dir = self.data_dir
    flags.FLAGS.batch_size = 128
    flags.FLAGS.train_epochs = 1
    flags.FLAGS.model_dir = self._get_model_dir('unit_test')
    flags.FLAGS.resnet_size = 8
    flags.FLAGS.dtype = 'fp32'
    flags.FLAGS.hooks = ['ExamplesPerSecondHook']
    self._run_and_report_benchmark()

  def _run_and_report_benchmark(self):
    """Executes the benchmark and reports the result."""
    start_time_sec = time.time()
    stats = cifar_main.run_cifar(flags.FLAGS)
    wall_time_sec = time.time() - start_time_sec

    self._report_benchmark(stats,
                           wall_time_sec,
                           top_1_min=0.926,
                           top_1_max=0.938)
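
For reference, the accuracy gates encode the expected converged results for these models: ResNet-50 on ImageNet should land in the 0.762-0.766 top-1 band after 90 epochs, and ResNet-56 on CIFAR-10 in the 0.926-0.938 band after 182 epochs. A passing ResNet-56 run would hand report_benchmark a metrics list of roughly this shape (values illustrative):

    metrics = [
        {'name': 'accuracy_top_1', 'value': 0.931,
         'min_value': 0.926, 'max_value': 0.938},
        {'name': 'exp_per_second', 'value': 1500.0},  # throughput, illustrative
    ]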
