Skip to content

Add support for keras callback #113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 10, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hyperdash/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.13.1
0.14.1
72 changes: 69 additions & 3 deletions hyperdash/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
# Python 2/3 compatibility
__metaclass__ = type

"""
No-op class for reusing CodeRunner architecture
"""
KERAS = "keras"


class ExperimentRunner:
"""
No-op class for reusing CodeRunner architecture
"""
def __init__(
self,
done=False,
Expand Down Expand Up @@ -81,6 +83,7 @@ def __init__(
2) capture_io: Should save stdout/stderror to log file and upload it to Hyperdash.
"""
self.model_name = model_name
self.callbacks = Callbacks(self)
self._experiment_runner = ExperimentRunner()
self.lock = Lock()

Expand Down Expand Up @@ -110,6 +113,7 @@ def __init__(
self._logger,
self._experiment_runner,
)

# Channel to update once experiment has finished running
# Syncs with the separate hyperdash messaging loop thread
self.done_chan = Queue()
Expand Down Expand Up @@ -161,3 +165,65 @@ def end(self):
"""
def log(self, string):
self._logger.info(string)


class Callbacks:
    """Callbacks is a container class for 3rd-party library callbacks.

    An instance of Experiment is injected so that the callbacks can emit
    metrics/logs/parameters on behalf of an experiment.
    """
    def __init__(self, exp):
        # Experiment on whose behalf the callbacks emit metrics.
        self._exp = exp
        # Lazy cache: library key -> callback instance, or the sentinel
        # False when the library turned out to be unimportable.
        self._callbacks = {}

    @property
    def keras(self):
        """
        Returns an object that implements the Keras Callback interface.

        This method initializes the Keras callback lazily to prevent
        any possible import issues from affecting users who don't use it,
        as well as prevent it from importing Keras/tensorflow and all of
        their accompanying baggage unnecessarily in the case that they
        happened to be installed, but the user is not using them.

        Returns None when Keras is not importable.
        """
        cb = self._callbacks.get(KERAS)
        # False is the "Keras is not importable" sentinel; use an identity
        # check so it is never confused with another falsy value.
        if cb is False:
            return None
        # First access: try to import Keras and fall back gracefully.
        if cb is None:
            try:
                from keras.callbacks import Callback as KerasCallback

                class _KerasCallback(KerasCallback):
                    """_KerasCallback implements KerasCallback using an injected Experiment.

                    # TODO: Decide if we want to handle the additional callbacks:
                    # 1) on_epoch_begin
                    # 2) on_batch_begin
                    # 3) on_batch_end
                    # 4) on_train_begin
                    # 5) on_train_end
                    """
                    def __init__(self, exp):
                        super(_KerasCallback, self).__init__()
                        self._exp = exp

                    def on_epoch_end(self, epoch, logs=None):
                        # No mutable default argument: Keras passes a dict,
                        # but tolerate being called without one.
                        if logs is None:
                            logs = {}
                        val_acc = logs.get("val_acc")
                        val_loss = logs.get("val_loss")

                        if val_acc is not None:
                            self._exp.metric("val_acc", val_acc)
                        if val_loss is not None:
                            self._exp.metric("val_loss", val_loss)
                cb = _KerasCallback(self._exp)
                self._callbacks[KERAS] = cb
                return cb
            except ImportError:
                # Mark Keras as unimportable so future calls short-circuit.
                self._callbacks[KERAS] = False
                return None
        return cb
5 changes: 5 additions & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@ jupyter==1.0.0
python-slugify==1.2.4
twine==1.9.1
numpy==1.13.3
keras==2.1.1
# Required for Keras
np_utils==0.5.3.4
# Required for Keras
tensorflow==1.4.0
29 changes: 29 additions & 0 deletions tests/test_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,35 @@ def test_experiment(self):
assert_in(log, data)
os.remove(latest_log_file)

def test_experiment_keras_callback(self):
    """End-to-end check that the Keras callback reports val_acc/val_loss."""
    with patch("sys.stdout", new=StringIO()) as faked_out:
        exp = Experiment("MNIST")
        keras_cb = exp.callbacks.keras
        keras_cb.on_epoch_end(0, {"val_acc": 1, "val_loss": 2})
        # Sleep 1 second due to client sampling
        time.sleep(1)
        keras_cb.on_epoch_end(1, {"val_acc": 3, "val_loss": 4})
        exp.end()

        # Only metric messages carry a "name" field in their payload.
        metrics_messages = [
            msg["payload"]
            for msg in server_sdk_messages
            if "name" in msg["payload"]
        ]
        expect_metrics = [
            {"is_internal": False, "name": "val_acc", "value": 1},
            {"is_internal": False, "name": "val_loss", "value": 2},
            {"is_internal": False, "name": "val_acc", "value": 3},
            {"is_internal": False, "name": "val_loss", "value": 4},
        ]
        # Direct list equality covers both the count and per-item checks.
        assert metrics_messages == expect_metrics

        assert "error" not in faked_out.getvalue()

def test_experiment_handles_numpy_numbers(self):
nums_to_test = [
("int_", np.int_()),
Expand Down