Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] Fix Stateful Metrics in fit_generator with TensorBoard #10673

Merged
merged 3 commits into from
Jul 23, 2018
Merged

[Bug] Fix Stateful Metrics in fit_generator with TensorBoard #10673

merged 3 commits into from
Jul 23, 2018

Conversation

briannemsick
Copy link
Contributor

Summary

Currently TensorBoard does not work with fit_generator because of

line 942, in on_epoch_end
    summary_value.simple_value = value.item()
AttributeError: 'float' object has no attribute 'item'

This is a simple casting issue.

Related Issues

#10628
#10623

PR Overview

  • [n] This PR requires new unit tests [y/n] (make sure tests are included)
  • [n] This PR requires to update the documentation [y/n] (make sure the docs are up-to-date)
  • [y] This PR is backwards compatible [y/n]
  • [n] This PR changes the current API [y/n] (all API changes need to be approved by fchollet)

@briannemsick
Copy link
Contributor Author

Here is some simple code that shows the issue:

import numpy as np

import keras
from keras.layers import Input, Dense
from keras.models import Model

# Dummy stateful metric
class BatchCounter(keras.layers.Layer):

        def __init__(self, name="batch_counter", **kwargs):
            super(BatchCounter, self).__init__(name=name, **kwargs)
            self.stateful = True
            self.batches = keras.backend.variable(value=0, dtype="float32")

        def reset_states(self):
            keras.backend.set_value(self.batches, 0)

        def __call__(self, y_true, y_pred):
            updates = [keras.backend.update_add(self.batches, keras.backend.variable(value=1, dtype="float32"))]
            self.add_update(updates)
            return self.batches

class DummyGenerator(object):
    """ Dummy data generator. """

    def run(self):
        while True:
            yield np.ones((10, 1)), np.zeros((10, 1))

train_gen = DummyGenerator()
val_gen = DummyGenerator()

# Dummy model
inputs = Input(shape=(1,))
outputs = Dense(1)(inputs)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss="mse", optimizer="adam", metrics=[BatchCounter()])

model.fit_generator(
    train_gen.run(),
    steps_per_epoch=5,
    epochs=10,
    validation_data=val_gen.run(),
    validation_steps=5,
    callbacks=[keras.callbacks.TensorBoard()])

@fchollet
Copy link
Collaborator

Thank you for the PR. Please add a unit test.

@briannemsick
Copy link
Contributor Author

briannemsick commented Jul 16, 2018

Ack.

I added a stateful metric to the main TensorBoard tests which test many combinations. Let me know if you would rather spin this out into a new test (but that would require some code duplication).

I would also like to add this to Tensorflow once the tests are up to standard.

I know you are busy, but I would appreciate a look at more stateful metric related issues in tf.keras:

tensorflow/tensorflow#20650
tensorflow/tensorflow#20529

Thanks :)

@briannemsick
Copy link
Contributor Author

Tests were added.

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

@brge17
Copy link
Contributor

brge17 commented Jul 23, 2018

tensorflow/tensorflow#21071

tf.keras equivalent PR

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants