Skip to content

Commit 62184a9

Browse files
tfboyd nnigania
authored and committed
[NCF] Add run_eagerly for ctl. (tensorflow#7229)
* Add run_eagerly for ctl. * fix test name and do not set "default".
1 parent 5834081 commit 62184a9

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

official/recommendation/ncf_keras_benchmark.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,13 @@ def benchmark_1_gpu_ctl_early_stop(self):
181181
FLAGS.early_stopping = True
182182
self._run_and_report_benchmark()
183183

184+
def benchmark_1_gpu_ctl_run_eagerly_early_stop(self):
185+
self._setup()
186+
FLAGS.keras_use_ctl = True
187+
FLAGS.early_stopping = True
188+
FLAGS.run_eagerly = True
189+
self._run_and_report_benchmark()
190+
184191
def benchmark_xla_1_gpu_ctl_early_stop(self):
185192
self._setup()
186193
FLAGS.keras_use_ctl = True
@@ -203,7 +210,7 @@ def benchmark_2_gpus_ctl_early_stop(self):
203210
self._run_and_report_benchmark()
204211

205212
#############################################
206-
# Tests below with mlperf in the test name are of two types
213+
# Tests below with mlperf in the test name are of two types:
207214
# 1) 1 GPU tests are based on MLPerf 0.5 and the TensorFlow pulled submission.
208215
# 2) 8 GPU tests are based on MLPerf 0.5 and use NVIDIA's hyper parameters.
209216
#
@@ -254,6 +261,14 @@ def benchmark_1_gpu_ctl_mlperf_like(self):
254261
FLAGS.train_epochs = 7
255262
self._run_and_report_benchmark_mlperf_like()
256263

264+
def benchmark_1_gpu_ctl_run_eagerly_mlperf_like(self):
265+
"""1 GPU using CTL with eager and distribution strategy."""
266+
self._setup()
267+
FLAGS.keras_use_ctl = True
268+
FLAGS.run_eagerly = True
269+
FLAGS.train_epochs = 7
270+
self._run_and_report_benchmark()
271+
257272
def benchmark_xla_1_gpu_ctl_mlperf_like(self):
258273
"""1 GPU using CTL with XLA."""
259274
self._setup()

official/recommendation/ncf_keras_main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,6 @@ def run_ncf(_):
285285
train_input_iterator = strategy.make_dataset_iterator(train_input_dataset)
286286
eval_input_iterator = strategy.make_dataset_iterator(eval_input_dataset)
287287

288-
@tf.function
289288
def train_step():
290289
"""Called once per step to train the model."""
291290
def step_fn(features):
@@ -310,7 +309,6 @@ def step_fn(features):
310309
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
311310
return mean_loss
312311

313-
@tf.function
314312
def eval_step():
315313
"""Called once per eval step to compute eval metrics."""
316314
def step_fn(features):
@@ -330,6 +328,10 @@ def step_fn(features):
330328
tf.distribute.ReduceOp.SUM, per_replica_hr_count, axis=None)
331329
return hr_sum, hr_count
332330

331+
if not FLAGS.run_eagerly:
332+
train_step = tf.function(train_step)
333+
eval_step = tf.function(eval_step)
334+
333335
time_callback.on_train_begin()
334336
for epoch in range(FLAGS.train_epochs):
335337
for cb in callbacks:

0 commit comments

Comments
 (0)