Skip to content

Commit 9b9e12a

Browse files
authored
Merge pull request #166 from Oxid15/unify_repos_and_lines
Unify repos and lines
2 parents 6f4397e + 3316e39 commit 9b9e12a

File tree

6 files changed

+263
-201
lines changed

6 files changed

+263
-201
lines changed

cascade/meta/history_viewer.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from flatten_json import flatten
2222
from deepdiff import DeepDiff
2323

24-
from ..models import ModelRepo
24+
from ..models import ModelRepo, ModelLine, SingleLineRepo
2525
from . import Server, MetaViewer
2626

2727

@@ -34,7 +34,7 @@ class HistoryViewer(Server):
3434

3535
def __init__(
3636
self,
37-
repo: ModelRepo,
37+
repo: Union[ModelRepo, ModelLine],
3838
last_lines: Union[int, None] = None,
3939
last_models: Union[int, None] = None,
4040
) -> None:
@@ -48,6 +48,8 @@ def __init__(
4848
last_models: int, optional
4949
For each line constraints the number of models back from the last one to view
5050
"""
51+
if isinstance(repo, ModelLine):
52+
repo = SingleLineRepo(repo)
5153
self._repo = repo
5254
self._last_lines = last_lines
5355
self._last_models = last_models
@@ -64,11 +66,12 @@ def _make_table(self) -> None:
6466

6567
for line_name in line_names:
6668
line = self._repo[line_name]
67-
view = MetaViewer(line.root, filt={"type": "model"})
69+
line_root = line.get_root()
70+
view = MetaViewer(line_root, filt={"type": "model"})
6871

6972
last_models = self._last_models if self._last_models is not None else 0
7073
for i in range(len(line))[-last_models:]:
71-
new_meta = {"line": line.root, "model": i}
74+
new_meta = {"line": line_root, "model": i}
7275
try:
7376
# TODO: to take only first is not good...
7477
meta = view[i][0]
@@ -79,7 +82,7 @@ def _make_table(self) -> None:
7982
metas.append(new_meta)
8083

8184
p = {
82-
"line": line.root,
85+
"line": line_root,
8386
}
8487
if "params" in meta:
8588
if len(meta["params"]) > 0:

cascade/meta/metric_viewer.py

+99-90
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import pandas as pd
2323

2424
from . import Server, MetaViewer
25-
from ..models import Model, ModelRepo
25+
from ..models import Model, ModelRepo, ModelLine, SingleLineRepo
2626

2727

2828
class MetricViewer:
@@ -32,7 +32,7 @@ class MetricViewer:
3232
As metrics it uses data from `metrics` field in models'
3333
meta and as parameters it uses `params` field.
3434
"""
35-
def __init__(self, repo: ModelRepo, scope: Union[int, str, slice, None] = None) -> None:
35+
def __init__(self, repo: Union[ModelRepo, ModelLine], scope: Union[int, str, slice, None] = None) -> None:
3636
"""
3737
Parameters
3838
----------
@@ -41,6 +41,8 @@ def __init__(self, repo: ModelRepo, scope: Union[int, str, slice, None] = None)
4141
scope: Union[int, str, slice]
4242
Index or a name of line to view. Can be set using `__getitem__`
4343
"""
44+
if isinstance(repo, ModelLine):
45+
repo = SingleLineRepo(repo)
4446
self._repo = repo
4547
self._scope = scope
4648
self._metrics = []
@@ -64,33 +66,29 @@ def reload_table(self) -> None:
6466

6567
for name in selected_names:
6668
line = self._repo[name]
67-
viewer_root = line.root
69+
viewer_root = line.get_root()
6870

69-
view = MetaViewer(viewer_root, filt={'type': 'model'})
71+
view = MetaViewer(viewer_root, filt={"type": "model"})
7072

7173
for i in range(len(line.model_names)):
7274
try:
7375
meta = view[i][-1] # Takes last model from meta
7476
except IndexError:
7577
meta = {}
7678

77-
metric = {
78-
'line': viewer_root,
79-
'num': i
80-
}
79+
metric = {"line": viewer_root, "num": i}
8180

82-
if 'created_at' in meta:
83-
metric['created_at'] = \
84-
pendulum.parse(meta['created_at'])
85-
if 'saved_at' in meta:
86-
metric['saved'] = \
87-
pendulum.parse(meta['saved_at']) \
88-
.diff_for_humans(metric['created_at'])
81+
if "created_at" in meta:
82+
metric["created_at"] = pendulum.parse(meta["created_at"])
83+
if "saved_at" in meta:
84+
metric["saved"] = pendulum.parse(
85+
meta["saved_at"]
86+
).diff_for_humans(metric["created_at"])
8987

90-
if 'metrics' in meta:
91-
metric.update(meta['metrics'])
92-
if 'params' in meta:
93-
metric.update(meta['params'])
88+
if "metrics" in meta:
89+
metric.update(meta["metrics"])
90+
if "params" in meta:
91+
metric.update(meta["params"])
9492

9593
self._metrics.append(metric)
9694
self.table = pd.DataFrame(self._metrics)
@@ -106,24 +104,30 @@ def plot_table(self, show: bool = False):
106104
try:
107105
import plotly
108106
except ModuleNotFoundError:
109-
raise ModuleNotFoundError('''
107+
raise ModuleNotFoundError(
108+
"""
110109
Cannot import plotly. It is conditional
111110
dependency you can install it
112-
using the instructions from plotly official documentation''')
111+
using the instructions from plotly official documentation"""
112+
)
113113
else:
114114
from plotly import graph_objects as go
115115

116-
data = pd.DataFrame(map(flatten, self.table.to_dict('records')))
117-
fig = go.Figure(data=[
118-
go.Table(
119-
header=dict(values=list(data.columns),
120-
fill_color='#f4c9c7',
121-
align='left'),
122-
cells=dict(values=[data[col] for col in data.columns],
123-
fill_color='#bcced4',
124-
align='left')
125-
)
126-
])
116+
data = pd.DataFrame(map(flatten, self.table.to_dict("records")))
117+
fig = go.Figure(
118+
data=[
119+
go.Table(
120+
header=dict(
121+
values=list(data.columns), fill_color="#f4c9c7", align="left"
122+
),
123+
cells=dict(
124+
values=[data[col] for col in data.columns],
125+
fill_color="#bcced4",
126+
align="left",
127+
),
128+
)
129+
]
130+
)
127131
if show:
128132
fig.show()
129133
return fig
@@ -145,25 +149,26 @@ def get_best_by(self, metric: str, maximize: bool = True) -> Model:
145149
TypeError if metric objects cannot be sorted. If only one model in repo, then
146150
returns it without error since no sorting involved.
147151
"""
148-
assert metric in self.table, f'{metric} is not in {self.table.columns}'
152+
assert metric in self.table, f"{metric} is not in {self.table.columns}"
149153
t = self.table.loc[self.table[metric].notna()]
150154

151155
try:
152156
t = t.sort_values(metric, ascending=maximize)
153157
except TypeError as e:
154-
raise TypeError(f'Metric {metric} objects cannot be sorted') from e
158+
raise TypeError(f"Metric {metric} objects cannot be sorted") from e
155159

156160
best_row = t.iloc[-1]
157-
name = os.path.split(best_row['line'])[-1]
158-
num = best_row['num']
161+
name = os.path.split(best_row["line"])[-1]
162+
num = best_row["num"]
159163
return self._repo[name][num]
160164

161165
def serve(
162-
self,
163-
page_size: int = 50,
164-
include: Union[List[str], None] = None,
165-
exclude: Union[List[str], None] = None,
166-
**kwargs: Any) -> None:
166+
self,
167+
page_size: int = 50,
168+
include: Union[List[str], None] = None,
169+
exclude: Union[List[str], None] = None,
170+
**kwargs: Any,
171+
) -> None:
167172
"""
168173
Runs dash-based server with interactive table of metrics and parameters
169174
@@ -179,15 +184,21 @@ def serve(
179184
**kwargs:
180185
Arguments of dash app. Can be ip or port for example
181186
"""
182-
server = MetricServer(self, page_size=page_size, include=include, exclude=exclude)
187+
server = MetricServer(
188+
self, page_size=page_size, include=include, exclude=exclude
189+
)
183190
server.serve(**kwargs)
184191

185192

186193
class MetricServer(Server):
187-
def __init__(self, mv: MetricViewer,
188-
page_size: int,
189-
include: Union[List[str], None],
190-
exclude: Union[List[str], None], **kwargs: Any) -> None:
194+
def __init__(
195+
self,
196+
mv: MetricViewer,
197+
page_size: int,
198+
include: Union[List[str], None],
199+
exclude: Union[List[str], None],
200+
**kwargs: Any,
201+
) -> None:
191202
self._mv = mv
192203
self._page_size = page_size
193204
self._include = include
@@ -202,20 +213,19 @@ def _update_graph_callback(self, _app) -> None:
202213
from plotly import graph_objects as go
203214

204215
@_app.callback(
205-
Output(component_id='dependence-figure', component_property='figure'),
206-
Input(component_id='dropdown-x', component_property='value'),
207-
Input(component_id='dropdown-y', component_property='value'))
216+
Output(component_id="dependence-figure", component_property="figure"),
217+
Input(component_id="dropdown-x", component_property="value"),
218+
Input(component_id="dropdown-y", component_property="value"),
219+
)
208220
def _update_graph(x, y):
209221
fig = go.Figure()
210222
if x is not None and y is not None:
211223
fig.add_trace(
212224
go.Scatter(
213-
x=self._df_flatten[x],
214-
y=self._df_flatten[y],
215-
mode='markers'
225+
x=self._df_flatten[x], y=self._df_flatten[y], mode="markers"
216226
)
217227
)
218-
fig.update_layout(title=f'{x} to {y} relation')
228+
fig.update_layout(title=f"{x} to {y} relation")
219229
return fig
220230

221231
def _layout(self):
@@ -234,46 +244,45 @@ def _layout(self):
234244
df = df.drop(self._exclude, axis=1)
235245

236246
if self._include is not None:
237-
df = df[['line', 'num'] + self._include]
247+
df = df[["line", "num"] + self._include]
238248

239-
self._df_flatten = pd.DataFrame(map(flatten, df.to_dict('records')))
249+
self._df_flatten = pd.DataFrame(map(flatten, df.to_dict("records")))
240250
dep_fig = go.Figure()
241251

242-
return html.Div([
243-
html.H1(
244-
children=f'MetricViewer in {self._mv._repo}',
245-
style={
246-
'textAlign': 'center',
247-
'color': '#084c61',
248-
'font-family': 'Montserrat'
249-
}
250-
),
251-
dcc.Dropdown(
252-
list(self._df_flatten.columns),
253-
id='dropdown-x',
254-
multi=False),
255-
dcc.Dropdown(
256-
list(self._df_flatten.columns),
257-
id='dropdown-y',
258-
multi=False),
259-
dcc.Graph(
260-
id='dependence-figure',
261-
figure=dep_fig),
262-
dash_table.DataTable(
263-
columns=[
264-
{'name': col, 'id': col, 'selectable': True} for col in self._df_flatten.columns
265-
],
266-
data=self._df_flatten.to_dict('records'),
267-
filter_action="native",
268-
sort_action="native",
269-
sort_mode="multi",
270-
selected_columns=[],
271-
selected_rows=[],
272-
page_action="native",
273-
page_current=0,
274-
page_size=self._page_size,
275-
)
276-
])
252+
return html.Div(
253+
[
254+
html.H1(
255+
children=f"MetricViewer in {self._mv._repo}",
256+
style={
257+
"textAlign": "center",
258+
"color": "#084c61",
259+
"font-family": "Montserrat",
260+
},
261+
),
262+
dcc.Dropdown(
263+
list(self._df_flatten.columns), id="dropdown-x", multi=False
264+
),
265+
dcc.Dropdown(
266+
list(self._df_flatten.columns), id="dropdown-y", multi=False
267+
),
268+
dcc.Graph(id="dependence-figure", figure=dep_fig),
269+
dash_table.DataTable(
270+
columns=[
271+
{"name": col, "id": col, "selectable": True}
272+
for col in self._df_flatten.columns
273+
],
274+
data=self._df_flatten.to_dict("records"),
275+
filter_action="native",
276+
sort_action="native",
277+
sort_mode="multi",
278+
selected_columns=[],
279+
selected_rows=[],
280+
page_action="native",
281+
page_current=0,
282+
page_size=self._page_size,
283+
),
284+
]
285+
)
277286

278287
def serve(self, **kwargs: Any) -> None:
279288
# Conditional import

cascade/models/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@
1616

1717
from .model import Model, ModelModifier
1818
from .basic_model import BasicModel, BasicModelModifier
19-
from .model_repo import ModelRepo
19+
from .model_repo import ModelRepo, SingleLineRepo
2020
from .model_line import ModelLine
2121
from .trainer import Trainer, BasicTrainer

0 commit comments

Comments
 (0)