Skip to content

Commit 18bdc98

Browse files
Merge pull request #113 from hyperdashio/ra/support-keras-callback
Add support for keras callback
2 parents 3a2c538 + 62b188e commit 18bdc98

File tree

4 files changed

+104
-4
lines changed

4 files changed

+104
-4
lines changed

hyperdash/VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.13.1
1+
0.14.1

hyperdash/experiment.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
# Python 2/3 compatibility
2121
__metaclass__ = type
2222

23-
"""
24-
No-op class for reusing CodeRunner architecture
25-
"""
23+
KERAS = "keras"
24+
2625

2726
class ExperimentRunner:
27+
"""
28+
No-op class for reusing CodeRunner architecture
29+
"""
2830
def __init__(
2931
self,
3032
done=False,
@@ -81,6 +83,7 @@ def __init__(
8183
2) capture_io: Should save stdout/stderror to log file and upload it to Hyperdash.
8284
"""
8385
self.model_name = model_name
86+
self.callbacks = Callbacks(self)
8487
self._experiment_runner = ExperimentRunner()
8588
self.lock = Lock()
8689

@@ -110,6 +113,7 @@ def __init__(
110113
self._logger,
111114
self._experiment_runner,
112115
)
116+
113117
# Channel to update once experiment has finished running
114118
# Syncs with the seperate hyperdash messaging loop thread
115119
self.done_chan = Queue()
@@ -161,3 +165,65 @@ def end(self):
161165
"""
162166
def log(self, string):
163167
self._logger.info(string)
168+
169+
170+
class Callbacks:
171+
"""Callbacks is a container class for 3rd-party library callbacks.
172+
173+
An instance of Experiment is injected so that the callbacks can emit
174+
metrics/logs/parameters on behalf of an experiment.
175+
"""
176+
def __init__(self, exp):
177+
self._exp = exp
178+
self._callbacks = {}
179+
180+
@property
181+
def keras(self):
182+
"""
183+
Returns an object that implements the Keras Callback interface.
184+
185+
This method initializes the Keras callback lazily to to prevent
186+
any possible import issues from affecting users who don't use it,
187+
as well as prevent it from importing Keras/tensorflow and all of
188+
their accompanying baggage unnecessarily in the case that they
189+
happened to be installed, but the user is not using them.
190+
"""
191+
cb = self._callbacks.get(KERAS)
192+
# Keras is not importable
193+
if cb == False:
194+
return None
195+
# If this is the first time, try and import Keras
196+
if not cb:
197+
# Check if Keras is installed and fallback gracefully
198+
try:
199+
from keras.callbacks import Callback as KerasCallback
200+
class _KerasCallback(KerasCallback):
201+
"""_KerasCallback implement KerasCallback using an injected Experiment.
202+
203+
# TODO: Decide if we want to handle the additional callbacks:
204+
# 1) on_epoch_begin
205+
# 2) on_batch_begin
206+
# 3) on_batch_end
207+
# 4) on_train_begin
208+
# 5) on_train_end
209+
"""
210+
def __init__(self, exp):
211+
super(_KerasCallback, self).__init__()
212+
self._exp = exp
213+
214+
def on_epoch_end(self, epoch, logs={}):
215+
val_acc = logs.get("val_acc")
216+
val_loss = logs.get("val_loss")
217+
218+
if val_acc is not None:
219+
self._exp.metric("val_acc", val_acc)
220+
if val_loss is not None:
221+
self._exp.metric("val_loss", val_loss)
222+
cb = _KerasCallback(self._exp)
223+
self._callbacks[KERAS] = cb
224+
return cb
225+
except ImportError:
226+
# Mark Keras as unimportable for future calls
227+
self._callbacks[KERAS] = False
228+
return None
229+
return cb

requirements_dev.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,8 @@ jupyter==1.0.0
77
python-slugify==1.2.4
88
twine==1.9.1
99
numpy==1.13.3
10+
keras==2.1.1
11+
# Required for Keras
12+
np_utils==0.5.3.4
13+
# Required for Keras
14+
tensorflow==1.4.0

tests/test_sdk.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,35 @@ def test_experiment(self):
402402
assert_in(log, data)
403403
os.remove(latest_log_file)
404404

405+
def test_experiment_keras_callback(self):
406+
with patch("sys.stdout", new=StringIO()) as faked_out:
407+
exp = Experiment("MNIST")
408+
keras_cb = exp.callbacks.keras
409+
keras_cb.on_epoch_end(0, {"val_acc": 1, "val_loss": 2})
410+
# Sleep 1 second due to client sampling
411+
time.sleep(1)
412+
keras_cb.on_epoch_end(1, {"val_acc": 3, "val_loss": 4})
413+
exp.end()
414+
415+
# Test metrics match what is expected
416+
metrics_messages = []
417+
for msg in server_sdk_messages:
418+
payload = msg["payload"]
419+
if "name" in payload:
420+
metrics_messages.append(payload)
421+
expect_metrics = [
422+
{"is_internal": False, "name": "val_acc", "value": 1},
423+
{"is_internal": False, "name": "val_loss", "value": 2},
424+
{"is_internal": False, "name": "val_acc", "value": 3},
425+
{"is_internal": False, "name": "val_loss", "value": 4},
426+
]
427+
assert len(expect_metrics) == len(metrics_messages)
428+
for i, message in enumerate(metrics_messages):
429+
assert message == expect_metrics[i]
430+
431+
captured_out = faked_out.getvalue()
432+
assert "error" not in captured_out
433+
405434
def test_experiment_handles_numpy_numbers(self):
406435
nums_to_test = [
407436
("int_", np.int_()),

0 commit comments

Comments
 (0)