|
20 | 20 | # Python 2/3 compatibility
|
21 | 21 | __metaclass__ = type
|
22 | 22 |
|
23 |
| -""" |
24 |
| - No-op class for reusing CodeRunner architecture |
25 |
| -""" |
| 23 | +KERAS = "keras" |
| 24 | + |
26 | 25 |
|
27 | 26 | class ExperimentRunner:
|
| 27 | + """ |
| 28 | + No-op class for reusing CodeRunner architecture |
| 29 | + """ |
28 | 30 | def __init__(
|
29 | 31 | self,
|
30 | 32 | done=False,
|
@@ -81,6 +83,7 @@ def __init__(
|
81 | 83 | 2) capture_io: Should save stdout/stderror to log file and upload it to Hyperdash.
|
82 | 84 | """
|
83 | 85 | self.model_name = model_name
|
| 86 | + self.callbacks = Callbacks(self) |
84 | 87 | self._experiment_runner = ExperimentRunner()
|
85 | 88 | self.lock = Lock()
|
86 | 89 |
|
@@ -110,6 +113,7 @@ def __init__(
|
110 | 113 | self._logger,
|
111 | 114 | self._experiment_runner,
|
112 | 115 | )
|
| 116 | + |
113 | 117 | # Channel to update once experiment has finished running
|
114 | 118 | # Syncs with the seperate hyperdash messaging loop thread
|
115 | 119 | self.done_chan = Queue()
|
@@ -161,3 +165,65 @@ def end(self):
|
161 | 165 | """
|
162 | 166 | def log(self, string):
|
163 | 167 | self._logger.info(string)
|
| 168 | + |
| 169 | + |
| 170 | +class Callbacks: |
| 171 | + """Callbacks is a container class for 3rd-party library callbacks. |
| 172 | + |
| 173 | + An instance of Experiment is injected so that the callbacks can emit |
| 174 | + metrics/logs/parameters on behalf of an experiment. |
| 175 | + """ |
| 176 | + def __init__(self, exp): |
| 177 | + self._exp = exp |
| 178 | + self._callbacks = {} |
| 179 | + |
| 180 | + @property |
| 181 | + def keras(self): |
| 182 | + """ |
| 183 | + Returns an object that implements the Keras Callback interface. |
| 184 | +
|
| 185 | + This method initializes the Keras callback lazily to to prevent |
| 186 | + any possible import issues from affecting users who don't use it, |
| 187 | + as well as prevent it from importing Keras/tensorflow and all of |
| 188 | + their accompanying baggage unnecessarily in the case that they |
| 189 | + happened to be installed, but the user is not using them. |
| 190 | + """ |
| 191 | + cb = self._callbacks.get(KERAS) |
| 192 | + # Keras is not importable |
| 193 | + if cb == False: |
| 194 | + return None |
| 195 | + # If this is the first time, try and import Keras |
| 196 | + if not cb: |
| 197 | + # Check if Keras is installed and fallback gracefully |
| 198 | + try: |
| 199 | + from keras.callbacks import Callback as KerasCallback |
| 200 | + class _KerasCallback(KerasCallback): |
| 201 | + """_KerasCallback implement KerasCallback using an injected Experiment. |
| 202 | + |
| 203 | + # TODO: Decide if we want to handle the additional callbacks: |
| 204 | + # 1) on_epoch_begin |
| 205 | + # 2) on_batch_begin |
| 206 | + # 3) on_batch_end |
| 207 | + # 4) on_train_begin |
| 208 | + # 5) on_train_end |
| 209 | + """ |
| 210 | + def __init__(self, exp): |
| 211 | + super(_KerasCallback, self).__init__() |
| 212 | + self._exp = exp |
| 213 | + |
| 214 | + def on_epoch_end(self, epoch, logs={}): |
| 215 | + val_acc = logs.get("val_acc") |
| 216 | + val_loss = logs.get("val_loss") |
| 217 | + |
| 218 | + if val_acc is not None: |
| 219 | + self._exp.metric("val_acc", val_acc) |
| 220 | + if val_loss is not None: |
| 221 | + self._exp.metric("val_loss", val_loss) |
| 222 | + cb = _KerasCallback(self._exp) |
| 223 | + self._callbacks[KERAS] = cb |
| 224 | + return cb |
| 225 | + except ImportError: |
| 226 | + # Mark Keras as unimportable for future calls |
| 227 | + self._callbacks[KERAS] = False |
| 228 | + return None |
| 229 | + return cb |
0 commit comments