Skip to content

Commit 000b427

Browse files
authored
Merge pull request #222 from Oxid15/develop
v0.12.1 - Security fixes and minor improvements
2 parents 15ea9ab + 8a73180 commit 000b427

File tree

7 files changed

+209
-105
lines changed

7 files changed

+209
-105
lines changed

cascade/__init__.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
"""
1616

1717

18-
__version__ = "0.12.0"
18+
__version__ = "0.12.1"
1919
__author__ = "Ilia Moiseev"
2020
__author_email__ = "ilia.moiseev.5@yandex.ru"
2121

cascade/meta/history_viewer.py

+101-97
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
from typing import Any, Dict, List, Union
2020

2121
import pandas as pd
22-
import pendulum
2322
from deepdiff import DeepDiff
2423
from flatten_json import flatten
2524

@@ -51,14 +50,22 @@ def __init__(
5150
last_models: int, optional
5251
For each line constraints the number of models back from the last one to view
5352
"""
53+
54+
try:
55+
import plotly
56+
except ModuleNotFoundError:
57+
self._raise_cannot_import_plotly()
58+
else:
59+
from plotly import express as px
60+
from plotly import graph_objects as go
61+
62+
self._px = px
63+
self._go = go
64+
5465
self._container = container
5566
self._last_lines = last_lines
5667
self._last_models = last_models
5768

58-
self._reload()
59-
self._make_table()
60-
61-
def _reload(self) -> None:
6269
repo = self._container
6370
if isinstance(self._container, ModelLine):
6471
repo = SingleLineRepo(self._container)
@@ -72,9 +79,16 @@ def _reload(self) -> None:
7279
for repo in repos:
7380
repo.reload()
7481

82+
self._edges = dict()
7583
self._repo = repo
7684
self._repos = {repo.get_root(): repo for repo in repos}
7785

86+
self._make_table()
87+
88+
def _update(self) -> None:
89+
self._repo.reload()
90+
self._make_table()
91+
7892
def _get_last_updated_lines(self, line_names: List[str]) -> List[str]:
7993
valid_lines = []
8094
updated_at = []
@@ -113,18 +127,17 @@ def _make_table(self) -> None:
113127

114128
last_models = self._last_models if self._last_models is not None else 0
115129
for i in range(len(line))[-last_models:]:
116-
new_meta = {"line": line_root, "model": i}
130+
line_name = os.path.split(line_root)[-1]
131+
new_meta = {"line": line_name, "model": i}
117132
try:
118-
# TODO: to take only first is not good...
119133
meta = view[i][0]
134+
new_meta.update(flatten(meta))
120135
except IndexError:
121-
meta = {}
122-
123-
new_meta.update(flatten(meta))
136+
pass
124137
metas.append(new_meta)
125138

126139
p = {
127-
"line": line_root,
140+
"line": line_name,
128141
}
129142
if "params" in meta:
130143
if len(meta["params"]) > 0:
@@ -136,6 +149,23 @@ def _make_table(self) -> None:
136149
if "saved_at" in self._table:
137150
self._table = self._table.sort_values("saved_at")
138151

152+
# turn time into evenly spaced intervals
153+
time = [i for i in range(len(self._table))]
154+
lines = self._table["line"].unique()
155+
156+
cmap = self._px.colors.qualitative.Plotly
157+
cmap_len = len(self._px.colors.qualitative.Plotly)
158+
line_cols = {line: cmap[i % cmap_len] for i, line in enumerate(lines)}
159+
160+
self._table["time"] = time
161+
self._table["color"] = [line_cols[line] for line in self._table["line"]]
162+
self._table = self._table.fillna("")
163+
164+
columns2fill = [
165+
col for col in self._table.columns if not col.startswith("metrics_")
166+
]
167+
self._table = self._table.fillna({name: "" for name in columns2fill})
168+
139169
@staticmethod
140170
def _diff(p1: Dict[Any, Any], params: Dict[Any, Any]) -> List:
141171
diff = [DeepDiff(p1, p2) for p2 in params]
@@ -169,6 +199,23 @@ def _preprocess_metric(self, metric):
169199

170200
return metric
171201

202+
def _connect_points(self, line: str, metric: str, fig: Any):
203+
edges = [0]
204+
params = [p for p in self._params if p["line"] == line]
205+
for i in range(1, len(params)):
206+
diff = self._diff(params[i], params[:i])
207+
edges.append(self._specific_argmin(diff, i))
208+
209+
xs = []
210+
ys = []
211+
t = self._table.loc[self._table["line"] == line]
212+
for i, e in enumerate(edges):
213+
xs += [t["time"].iloc[i], t["time"].iloc[e], None]
214+
ys += [t[metric].iloc[i], t[metric].iloc[e], None]
215+
216+
self._edges[line] = {"edges": (xs, ys), "len": len(t)}
217+
return xs, ys
218+
172219
def plot(self, metric: str, show: bool = False) -> Any:
173220
"""
174221
Plots training history of model versions using plotly.
@@ -180,109 +227,76 @@ def plot(self, metric: str, show: bool = False) -> Any:
180227
show: bool, optional
181228
Whether to return and show or just return figure
182229
"""
183-
try:
184-
import plotly
185-
except ModuleNotFoundError:
186-
self._raise_cannot_import_plotly()
187-
else:
188-
from plotly import express as px
189-
from plotly import graph_objects as go
190-
191-
metric = self._preprocess_metric(metric)
192-
193-
# turn time into evenly spaced intervals
194-
time = [i for i in range(len(self._table))]
195-
lines = self._table["line"].unique()
196-
197-
cmap = px.colors.qualitative.Plotly
198-
cmap_len = len(px.colors.qualitative.Plotly)
199-
line_cols = {line: cmap[i % cmap_len] for i, line in enumerate(lines)}
200-
201-
self._table["time"] = time
202-
self._table["color"] = [line_cols[line] for line in self._table["line"]]
203-
table = self._table.fillna("")
204-
205-
columns2fill = [
206-
col for col in self._table.columns if not col.startswith("metrics_")
207-
]
208-
table = self._table.fillna({name: "" for name in columns2fill})
209230

210231
# plot each model against metric
211232
# with all metadata on hover
233+
metric = self._preprocess_metric(metric)
212234

213235
hover_cols = [name for name in pd.DataFrame(self._params).columns]
214-
if "saved_at" in table.columns:
236+
if "saved_at" in self._table.columns:
215237
hover_cols = ["saved_at"] + hover_cols
216238
hover_cols = ["model"] + hover_cols
217-
fig = px.scatter(table, x="time", y=metric, hover_data=hover_cols, color="line")
218-
219-
# determine connections between models
220-
# plot each one with respected color
239+
fig = self._px.scatter(self._table, x="time", y=metric, hover_data=hover_cols, color="line")
240+
lines = self._table["line"].unique()
221241

222242
for line in lines:
223-
params = [p for p in self._params if p["line"] == line]
224-
edges = []
225-
for i in range(len(params)):
226-
if i == 0:
227-
edges.append(0)
228-
continue
229-
else:
230-
diff = self._diff(params[i], params[:i])
231-
edges.append(self._specific_argmin(diff, i))
232-
233-
xs = []
234-
ys = []
235-
t = table.loc[table["line"] == line]
236-
for i, e in enumerate(edges):
237-
xs += [t["time"].iloc[i], t["time"].iloc[e], None]
238-
ys += [t[metric].iloc[i], t[metric].iloc[e], None]
243+
t = self._table.loc[self._table.line == line]
244+
self._connect_points(line, metric, fig)
239245

246+
xs, ys = self._edges[line]["edges"]
240247
fig.add_trace(
241-
go.Scatter(
248+
self._go.Scatter(
242249
x=xs,
243250
y=ys,
244251
mode="lines",
245-
marker={"color": t["color"].iloc[0]},
246252
name=line,
247253
hoverinfo="none",
254+
marker_color=t["color"].iloc[0]
248255
)
249256
)
250257

251-
# Create human-readable ticks
252-
now = pendulum.now(tz="UTC")
253-
time_text = table["saved_at"].apply(
254-
lambda t: t if t == "" else pendulum.parse(t).diff_for_humans(now)
255-
)
256-
257-
fig.update_layout(
258-
hovermode="x",
259-
xaxis=dict(
260-
tickmode="array",
261-
tickvals=[i for i in range(len(time))],
262-
ticktext=time_text,
263-
),
264-
)
265258
if show:
266259
fig.show()
267260

268261
return fig
269262

270-
def _layout(self, metric):
263+
def _update_plot(self, metric: str) -> Any:
264+
metric = self._preprocess_metric(metric)
265+
266+
hover_cols = [name for name in pd.DataFrame(self._params).columns]
267+
if "saved_at" in self._table.columns:
268+
hover_cols = ["saved_at"] + hover_cols
269+
hover_cols = ["model"] + hover_cols
270+
fig = self._px.scatter(self._table, x="time", y=metric, hover_data=hover_cols, color="line")
271+
272+
for line in sorted(self._table.line.unique()):
273+
t = self._table.loc[self._table.line == line]
274+
if line in self._edges and len(t) == self._edges[line]["len"]:
275+
xs, ys = self._edges[line]["edges"]
276+
else:
277+
xs, ys = self._connect_points(line, metric, fig)
278+
fig.add_trace(
279+
self._go.Scatter(
280+
x=xs,
281+
y=ys,
282+
mode="lines",
283+
name=line,
284+
hoverinfo="none",
285+
marker_color=t["color"].iloc[0]
286+
)
287+
)
288+
289+
return fig
290+
291+
def _layout(self, metric: Union[str, None]):
271292
try:
272293
import dash
273294
except ModuleNotFoundError:
274295
self._raise_cannot_import_dash()
275296
else:
276297
from dash import Input, Output, dcc, html
277298

278-
try:
279-
import plotly
280-
except ModuleNotFoundError:
281-
self._raise_cannot_import_plotly()
282-
else:
283-
from plotly import graph_objects as go
284-
285-
fig = self.plot(metric) if metric is not None else go.Figure()
299+
fig = self.plot(metric) if metric is not None else self._go.Figure()
286300

287301
return html.Div(
288302
[
@@ -337,15 +351,8 @@ def serve(self, metric: Union[str, None] = None, **kwargs: Any) -> None:
337351
else:
338352
from dash import Input, Output
339353

340-
try:
341-
import plotly
342-
except ModuleNotFoundError:
343-
self._raise_cannot_import_plotly()
344-
else:
345-
from plotly import graph_objects as go
346-
347354
app = dash.Dash()
348-
app.layout = self._layout(metric)
355+
app.layout = lambda: self._layout(metric)
349356

350357
@app.callback(
351358
Output("viewer-title", "children"), Input("history-interval", "n_intervals")
@@ -371,10 +378,9 @@ def update_dropdown(n_intervals):
371378
prevent_initial_call=True,
372379
)
373380
def update_history(n_intervals, metric):
374-
self._reload()
375-
376-
self._make_table()
377-
return self.plot(metric) if metric is not None else go.Figure()
381+
self._update()
382+
return (self._update_plot(metric)
383+
if metric is not None else self._go.Figure())
378384

379385
@app.callback(
380386
Output("metric-dropwdown", "value"),
@@ -385,6 +391,4 @@ def update_repos(name):
385391
if isinstance(self._container, Workspace):
386392
self._container.set_default(os.path.split(name)[-1])
387393

388-
return None
389-
390394
app.run_server(use_reloader=False, **kwargs)

cascade/models/model_repo.py

+49-3
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
HistoryLogger,
2424
JSONEncoder,
2525
MetaHandler,
26+
MetaFromFile,
2627
PipeMeta,
2728
Traceable,
2829
TraceableOnDisk,
@@ -50,7 +51,6 @@ def __init__(
5051
self._lines = dict()
5152

5253
def reload(self) -> None:
53-
# TODO: implement full reload
5454
for line in self._lines:
5555
self._lines[line].reload()
5656

@@ -254,14 +254,60 @@ def __repr__(self) -> str:
254254

255255
def reload(self) -> None:
256256
"""
257-
Updates internal state.
257+
Updates internal state
258258
"""
259259
super().reload()
260+
self._update_lines()
260261
self._update_meta()
261262

262-
def __add__(self, repo) -> "ModelRepoConcatenator":
263+
def __add__(self, repo: "ModelRepo") -> "ModelRepoConcatenator":
263264
return ModelRepoConcatenator([self, repo])
264265

266+
def load_model_meta(self, model: str) -> MetaFromFile:
267+
"""
268+
Loads metadata of a model from disk
269+
270+
Parameters
271+
----------
272+
model : str
273+
model slug e.g. `fair_squid_of_bliss`
274+
275+
Returns
276+
-------
277+
MetaFromFile
278+
Model metadata
279+
280+
Raises
281+
------
282+
FileNotFoundError
283+
Raises if failed to find the model with slug specified
284+
"""
285+
286+
for line in self._lines.values():
287+
try:
288+
meta = line.load_model_meta(model)
289+
except FileNotFoundError:
290+
continue
291+
else:
292+
return meta
293+
raise FileNotFoundError(
294+
f"Failed to find the model {model} in the repo at {self._root}"
295+
)
296+
297+
def _update_lines(self) -> None:
298+
for name in sorted(os.listdir(self._root)):
299+
if (
300+
os.path.isdir(os.path.join(self._root, name))
301+
and name not in self._lines
302+
):
303+
self._lines[name] = ModelLine(
304+
os.path.join(self._root, name),
305+
model_cls=self._model_cls
306+
if isinstance(self._model_cls, type)
307+
else self._model_cls[name],
308+
meta_fmt=self._meta_fmt,
309+
)
310+
265311

266312
class ModelRepoConcatenator(Repo):
267313
"""

0 commit comments

Comments
 (0)