Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Write to TensorBoard every x samples. #11152

Merged
merged 17 commits into from
Oct 1, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
32f7df2
Working on improving tensor flow callbacks
sameermanek May 18, 2017
7ceefaa
Adding batch level TensorBoard logging (implementing the `on_batch_en…
sameermanek May 19, 2017
186a4fb
Interim commit -- added notes.
sameermanek Aug 12, 2017
1390f21
Updating to master
sameermanek Aug 12, 2017
a7b3587
Merge branch 'master' of https://github.com/fchollet/keras into tenso…
sameermanek Aug 12, 2017
138b5c5
Corrected stylistic issues -- brought to compliance w/ PEP8
sameermanek Aug 13, 2017
3ba7cb3
Merge branch 'tensorboard-callback-modifications' of https://github.c…
gabrieldemarmiesse Sep 16, 2018
463b704
Added the missing argument in the test suite.
gabrieldemarmiesse Sep 16, 2018
f201418
Added the possibility to choose how frequently tensorboard should log
gabrieldemarmiesse Sep 16, 2018
284d903
Fixed the issue of the validation data not being displayed.
gabrieldemarmiesse Sep 16, 2018
22573c7
Fixed the issue about the callback not remembering when was the last
gabrieldemarmiesse Sep 16, 2018
f4efd69
Removed the error check.
gabrieldemarmiesse Sep 16, 2018
06cc111
Used update_freq instead of write_step.
gabrieldemarmiesse Sep 17, 2018
a1619ae
Forgot to change the constructor call.
gabrieldemarmiesse Sep 17, 2018
cf4c647
Merge branch 'master' into batch_tensorboard
gabrieldemarmiesse Sep 19, 2018
358abe7
Merge branch 'master' into batch_tensorboard
gabrieldemarmiesse Sep 23, 2018
072f923
Merge branch 'batch_tensorboard' of github.com:gabrieldemarmiesse/ker…
gabrieldemarmiesse Sep 23, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions keras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,12 @@ class TensorBoard(Callback):
input) or list of Numpy arrays (if the model has multiple inputs).
Learn [more about embeddings]
(https://www.tensorflow.org/programmers_guide/embedding).
update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`, writes
the losses and metrics to TensorBoard after each batch. The same
applies for `'epoch'`. If using an integer, let's say `10000`,
the callback will write the metrics and losses to TensorBoard every
10000 samples. Note that writing too frequently to TensorBoard
can slow down your training.
"""

def __init__(self, log_dir='./logs',
Expand All @@ -730,7 +736,8 @@ def __init__(self, log_dir='./logs',
embeddings_freq=0,
embeddings_layer_names=None,
embeddings_metadata=None,
embeddings_data=None):
embeddings_data=None,
update_freq='epoch'):
super(TensorBoard, self).__init__()
global tf, projector
try:
Expand Down Expand Up @@ -769,6 +776,13 @@ def __init__(self, log_dir='./logs',
self.embeddings_metadata = embeddings_metadata or {}
self.batch_size = batch_size
self.embeddings_data = embeddings_data
if update_freq == 'batch':
# It is the same as writing as frequently as possible.
self.update_freq = 1
else:
self.update_freq = update_freq
self.samples_seen = 0
self.samples_seen_at_last_write = 0

def set_model(self, model):
self.model = model
Expand Down Expand Up @@ -968,6 +982,13 @@ def on_epoch_end(self, epoch, logs=None):

i += self.batch_size

if self.update_freq == 'epoch':
index = epoch
else:
index = self.samples_seen
self._write_logs(logs, index)

def _write_logs(self, logs, index):
for name, value in logs.items():
if name in ['batch', 'size']:
continue
Expand All @@ -978,12 +999,20 @@ def on_epoch_end(self, epoch, logs=None):
else:
summary_value.simple_value = value
summary_value.tag = name
self.writer.add_summary(summary, epoch)
self.writer.add_summary(summary, index)
self.writer.flush()

def on_train_end(self, _):
    # Close the summary writer when training finishes so any buffered
    # events are flushed; the second positional argument is the `logs`
    # dict supplied by the callback machinery, unused here.
    self.writer.close()

def on_batch_end(self, batch, logs=None):
    """Count samples processed and write logs every `update_freq` samples.

    When `update_freq` is `'epoch'` this is a no-op: epoch-level writing
    is performed by `on_epoch_end` instead. Otherwise `update_freq` is an
    integer threshold (number of samples) between consecutive writes.
    """
    # Epoch-granularity logging is handled elsewhere; nothing to do here.
    if self.update_freq == 'epoch':
        return
    # `logs['size']` is the number of samples in the batch just finished.
    self.samples_seen += logs['size']
    pending = self.samples_seen - self.samples_seen_at_last_write
    if pending >= self.update_freq:
        self._write_logs(logs, self.samples_seen)
        self.samples_seen_at_last_write = self.samples_seen


class ReduceLROnPlateau(Callback):
"""Reduce learning rate when a metric has stopped improving.
Expand Down
6 changes: 4 additions & 2 deletions tests/keras/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,8 @@ def make_model():
assert not tmpdir.listdir()


def test_TensorBoard(tmpdir):
@pytest.mark.parametrize('update_freq', ['batch', 'epoch', 9])
def test_TensorBoard(tmpdir, update_freq):
np.random.seed(np.random.randint(1, 1e7))
filepath = str(tmpdir / 'logs')

Expand Down Expand Up @@ -588,7 +589,8 @@ def callbacks_factory(histogram_freq, embeddings_freq=1):
embeddings_freq=embeddings_freq,
embeddings_layer_names=['dense_1'],
embeddings_data=X_test,
batch_size=5)]
batch_size=5,
update_freq=update_freq)]

# fit without validation data
model.fit(X_train, y_train, batch_size=batch_size,
Expand Down