19
19
from typing import Any , Dict , List , Union
20
20
21
21
import pandas as pd
22
- import pendulum
23
22
from deepdiff import DeepDiff
24
23
from flatten_json import flatten
25
24
@@ -51,14 +50,22 @@ def __init__(
51
50
last_models: int, optional
52
51
For each line constraints the number of models back from the last one to view
53
52
"""
53
+
54
+ try :
55
+ import plotly
56
+ except ModuleNotFoundError :
57
+ self ._raise_cannot_import_plotly ()
58
+ else :
59
+ from plotly import express as px
60
+ from plotly import graph_objects as go
61
+
62
+ self ._px = px
63
+ self ._go = go
64
+
54
65
self ._container = container
55
66
self ._last_lines = last_lines
56
67
self ._last_models = last_models
57
68
58
- self ._reload ()
59
- self ._make_table ()
60
-
61
- def _reload (self ) -> None :
62
69
repo = self ._container
63
70
if isinstance (self ._container , ModelLine ):
64
71
repo = SingleLineRepo (self ._container )
@@ -72,9 +79,16 @@ def _reload(self) -> None:
72
79
for repo in repos :
73
80
repo .reload ()
74
81
82
+ self ._edges = dict ()
75
83
self ._repo = repo
76
84
self ._repos = {repo .get_root (): repo for repo in repos }
77
85
86
+ self ._make_table ()
87
+
88
+ def _update (self ) -> None :
89
+ self ._repo .reload ()
90
+ self ._make_table ()
91
+
78
92
def _get_last_updated_lines (self , line_names : List [str ]) -> List [str ]:
79
93
valid_lines = []
80
94
updated_at = []
@@ -113,18 +127,17 @@ def _make_table(self) -> None:
113
127
114
128
last_models = self ._last_models if self ._last_models is not None else 0
115
129
for i in range (len (line ))[- last_models :]:
116
- new_meta = {"line" : line_root , "model" : i }
130
+ line_name = os .path .split (line_root )[- 1 ]
131
+ new_meta = {"line" : line_name , "model" : i }
117
132
try :
118
- # TODO: to take only first is not good...
119
133
meta = view [i ][0 ]
134
+ new_meta .update (flatten (meta ))
120
135
except IndexError :
121
- meta = {}
122
-
123
- new_meta .update (flatten (meta ))
136
+ pass
124
137
metas .append (new_meta )
125
138
126
139
p = {
127
- "line" : line_root ,
140
+ "line" : line_name ,
128
141
}
129
142
if "params" in meta :
130
143
if len (meta ["params" ]) > 0 :
@@ -136,6 +149,23 @@ def _make_table(self) -> None:
136
149
if "saved_at" in self ._table :
137
150
self ._table = self ._table .sort_values ("saved_at" )
138
151
152
+ # turn time into evenly spaced intervals
153
+ time = [i for i in range (len (self ._table ))]
154
+ lines = self ._table ["line" ].unique ()
155
+
156
+ cmap = self ._px .colors .qualitative .Plotly
157
+ cmap_len = len (self ._px .colors .qualitative .Plotly )
158
+ line_cols = {line : cmap [i % cmap_len ] for i , line in enumerate (lines )}
159
+
160
+ self ._table ["time" ] = time
161
+ self ._table ["color" ] = [line_cols [line ] for line in self ._table ["line" ]]
162
+ self ._table = self ._table .fillna ("" )
163
+
164
+ columns2fill = [
165
+ col for col in self ._table .columns if not col .startswith ("metrics_" )
166
+ ]
167
+ self ._table = self ._table .fillna ({name : "" for name in columns2fill })
168
+
139
169
@staticmethod
140
170
def _diff (p1 : Dict [Any , Any ], params : Dict [Any , Any ]) -> List :
141
171
diff = [DeepDiff (p1 , p2 ) for p2 in params ]
@@ -169,6 +199,23 @@ def _preprocess_metric(self, metric):
169
199
170
200
return metric
171
201
202
+ def _connect_points (self , line : str , metric : str , fig : Any ):
203
+ edges = [0 ]
204
+ params = [p for p in self ._params if p ["line" ] == line ]
205
+ for i in range (1 , len (params )):
206
+ diff = self ._diff (params [i ], params [:i ])
207
+ edges .append (self ._specific_argmin (diff , i ))
208
+
209
+ xs = []
210
+ ys = []
211
+ t = self ._table .loc [self ._table ["line" ] == line ]
212
+ for i , e in enumerate (edges ):
213
+ xs += [t ["time" ].iloc [i ], t ["time" ].iloc [e ], None ]
214
+ ys += [t [metric ].iloc [i ], t [metric ].iloc [e ], None ]
215
+
216
+ self ._edges [line ] = {"edges" : (xs , ys ), "len" : len (t )}
217
+ return xs , ys
218
+
172
219
def plot (self , metric : str , show : bool = False ) -> Any :
173
220
"""
174
221
Plots training history of model versions using plotly.
@@ -180,109 +227,76 @@ def plot(self, metric: str, show: bool = False) -> Any:
180
227
show: bool, optional
181
228
Whether to return and show or just return figure
182
229
"""
183
- try :
184
- import plotly
185
- except ModuleNotFoundError :
186
- self ._raise_cannot_import_plotly ()
187
- else :
188
- from plotly import express as px
189
- from plotly import graph_objects as go
190
-
191
- metric = self ._preprocess_metric (metric )
192
-
193
- # turn time into evenly spaced intervals
194
- time = [i for i in range (len (self ._table ))]
195
- lines = self ._table ["line" ].unique ()
196
-
197
- cmap = px .colors .qualitative .Plotly
198
- cmap_len = len (px .colors .qualitative .Plotly )
199
- line_cols = {line : cmap [i % cmap_len ] for i , line in enumerate (lines )}
200
-
201
- self ._table ["time" ] = time
202
- self ._table ["color" ] = [line_cols [line ] for line in self ._table ["line" ]]
203
- table = self ._table .fillna ("" )
204
-
205
- columns2fill = [
206
- col for col in self ._table .columns if not col .startswith ("metrics_" )
207
- ]
208
- table = self ._table .fillna ({name : "" for name in columns2fill })
209
230
210
231
# plot each model against metric
211
232
# with all metadata on hover
233
+ metric = self ._preprocess_metric (metric )
212
234
213
235
hover_cols = [name for name in pd .DataFrame (self ._params ).columns ]
214
- if "saved_at" in table .columns :
236
+ if "saved_at" in self . _table .columns :
215
237
hover_cols = ["saved_at" ] + hover_cols
216
238
hover_cols = ["model" ] + hover_cols
217
- fig = px .scatter (table , x = "time" , y = metric , hover_data = hover_cols , color = "line" )
218
-
219
- # determine connections between models
220
- # plot each one with respected color
239
+ fig = self ._px .scatter (self ._table , x = "time" , y = metric , hover_data = hover_cols , color = "line" )
240
+ lines = self ._table ["line" ].unique ()
221
241
222
242
for line in lines :
223
- params = [p for p in self ._params if p ["line" ] == line ]
224
- edges = []
225
- for i in range (len (params )):
226
- if i == 0 :
227
- edges .append (0 )
228
- continue
229
- else :
230
- diff = self ._diff (params [i ], params [:i ])
231
- edges .append (self ._specific_argmin (diff , i ))
232
-
233
- xs = []
234
- ys = []
235
- t = table .loc [table ["line" ] == line ]
236
- for i , e in enumerate (edges ):
237
- xs += [t ["time" ].iloc [i ], t ["time" ].iloc [e ], None ]
238
- ys += [t [metric ].iloc [i ], t [metric ].iloc [e ], None ]
243
+ t = self ._table .loc [self ._table .line == line ]
244
+ self ._connect_points (line , metric , fig )
239
245
246
+ xs , ys = self ._edges [line ]["edges" ]
240
247
fig .add_trace (
241
- go .Scatter (
248
+ self . _go .Scatter (
242
249
x = xs ,
243
250
y = ys ,
244
251
mode = "lines" ,
245
- marker = {"color" : t ["color" ].iloc [0 ]},
246
252
name = line ,
247
253
hoverinfo = "none" ,
254
+ marker_color = t ["color" ].iloc [0 ]
248
255
)
249
256
)
250
257
251
- # Create human-readable ticks
252
- now = pendulum .now (tz = "UTC" )
253
- time_text = table ["saved_at" ].apply (
254
- lambda t : t if t == "" else pendulum .parse (t ).diff_for_humans (now )
255
- )
256
-
257
- fig .update_layout (
258
- hovermode = "x" ,
259
- xaxis = dict (
260
- tickmode = "array" ,
261
- tickvals = [i for i in range (len (time ))],
262
- ticktext = time_text ,
263
- ),
264
- )
265
258
if show :
266
259
fig .show ()
267
260
268
261
return fig
269
262
270
- def _layout (self , metric ):
263
+ def _update_plot (self , metric : str ) -> Any :
264
+ metric = self ._preprocess_metric (metric )
265
+
266
+ hover_cols = [name for name in pd .DataFrame (self ._params ).columns ]
267
+ if "saved_at" in self ._table .columns :
268
+ hover_cols = ["saved_at" ] + hover_cols
269
+ hover_cols = ["model" ] + hover_cols
270
+ fig = self ._px .scatter (self ._table , x = "time" , y = metric , hover_data = hover_cols , color = "line" )
271
+
272
+ for line in sorted (self ._table .line .unique ()):
273
+ t = self ._table .loc [self ._table .line == line ]
274
+ if line in self ._edges and len (t ) == self ._edges [line ]["len" ]:
275
+ xs , ys = self ._edges [line ]["edges" ]
276
+ else :
277
+ xs , ys = self ._connect_points (line , metric , fig )
278
+ fig .add_trace (
279
+ self ._go .Scatter (
280
+ x = xs ,
281
+ y = ys ,
282
+ mode = "lines" ,
283
+ name = line ,
284
+ hoverinfo = "none" ,
285
+ marker_color = t ["color" ].iloc [0 ]
286
+ )
287
+ )
288
+
289
+ return fig
290
+
291
+ def _layout (self , metric : Union [str , None ]):
271
292
try :
272
293
import dash
273
294
except ModuleNotFoundError :
274
295
self ._raise_cannot_import_dash ()
275
296
else :
276
297
from dash import Input , Output , dcc , html
277
298
278
- try :
279
- import plotly
280
- except ModuleNotFoundError :
281
- self ._raise_cannot_import_plotly ()
282
- else :
283
- from plotly import graph_objects as go
284
-
285
- fig = self .plot (metric ) if metric is not None else go .Figure ()
299
+ fig = self .plot (metric ) if metric is not None else self ._go .Figure ()
286
300
287
301
return html .Div (
288
302
[
@@ -337,15 +351,8 @@ def serve(self, metric: Union[str, None] = None, **kwargs: Any) -> None:
337
351
else :
338
352
from dash import Input , Output
339
353
340
- try :
341
- import plotly
342
- except ModuleNotFoundError :
343
- self ._raise_cannot_import_plotly ()
344
- else :
345
- from plotly import graph_objects as go
346
-
347
354
app = dash .Dash ()
348
- app .layout = self ._layout (metric )
355
+ app .layout = lambda : self ._layout (metric )
349
356
350
357
@app .callback (
351
358
Output ("viewer-title" , "children" ), Input ("history-interval" , "n_intervals" )
@@ -371,10 +378,9 @@ def update_dropdown(n_intervals):
371
378
prevent_initial_call = True ,
372
379
)
373
380
def update_history (n_intervals , metric ):
374
- self ._reload ()
375
-
376
- self ._make_table ()
377
- return self .plot (metric ) if metric is not None else go .Figure ()
381
+ self ._update ()
382
+ return (self ._update_plot (metric )
383
+ if metric is not None else self ._go .Figure ())
378
384
379
385
@app .callback (
380
386
Output ("metric-dropwdown" , "value" ),
@@ -385,6 +391,4 @@ def update_repos(name):
385
391
if isinstance (self ._container , Workspace ):
386
392
self ._container .set_default (os .path .split (name )[- 1 ])
387
393
388
- return None
389
-
390
394
app .run_server (use_reloader = False , ** kwargs )
0 commit comments