-
Notifications
You must be signed in to change notification settings - Fork 19.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Bug] Fix Stateful Metrics in fit_generator with TensorBoard #10673
Conversation
Here is some simple code that shows the issue: import numpy as np
import keras
from keras.layers import Input, Dense
from keras.models import Model
# Dummy stateful metric
class BatchCounter(keras.layers.Layer):
def __init__(self, name="batch_counter", **kwargs):
super(BatchCounter, self).__init__(name=name, **kwargs)
self.stateful = True
self.batches = keras.backend.variable(value=0, dtype="float32")
def reset_states(self):
keras.backend.set_value(self.batches, 0)
def __call__(self, y_true, y_pred):
updates = [keras.backend.update_add(self.batches, keras.backend.variable(value=1, dtype="float32"))]
self.add_update(updates)
return self.batches
class DummyGenerator(object):
""" Dummy data generator. """
def run(self):
while True:
yield np.ones((10, 1)), np.zeros((10, 1))
train_gen = DummyGenerator()
val_gen = DummyGenerator()
# Dummy model
inputs = Input(shape=(1,))
outputs = Dense(1)(inputs)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss="mse", optimizer="adam", metrics=[BatchCounter()])
model.fit_generator(
train_gen.run(),
steps_per_epoch=5,
epochs=10,
validation_data=val_gen.run(),
validation_steps=5,
callbacks=[keras.callbacks.TensorBoard()]) |
Thank you for the PR. Please add a unit test. |
Ack. I added a stateful metric to the main TensorBoard tests which test many combinations. Let me know if you would rather spin this out into a new test (but that would require some code duplication). I would also like to add this to Tensorflow once the tests are up to standard. I know you are busy, but I would appreciate a look at more stateful metric related issues in tensorflow/tensorflow#20650 Thanks :) |
Tests were added. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks
tf.keras equivalent PR |
Summary
Currently
TensorBoard
does not work withfit_generator
because ofThis is a simple casting issue.
Related Issues
#10628
#10623
PR Overview