Skip to content

Unify repos and lines #166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 27, 2023
13 changes: 8 additions & 5 deletions cascade/meta/history_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from flatten_json import flatten
from deepdiff import DeepDiff

from ..models import ModelRepo
from ..models import ModelRepo, ModelLine, SingleLineRepo
from . import Server, MetaViewer


Expand All @@ -34,7 +34,7 @@ class HistoryViewer(Server):

def __init__(
self,
repo: ModelRepo,
repo: Union[ModelRepo, ModelLine],
last_lines: Union[int, None] = None,
last_models: Union[int, None] = None,
) -> None:
Expand All @@ -48,6 +48,8 @@ def __init__(
last_models: int, optional
For each line constraints the number of models back from the last one to view
"""
if isinstance(repo, ModelLine):
repo = SingleLineRepo(repo)
self._repo = repo
self._last_lines = last_lines
self._last_models = last_models
Expand All @@ -64,11 +66,12 @@ def _make_table(self) -> None:

for line_name in line_names:
line = self._repo[line_name]
view = MetaViewer(line.root, filt={"type": "model"})
line_root = line.get_root()
view = MetaViewer(line_root, filt={"type": "model"})

last_models = self._last_models if self._last_models is not None else 0
for i in range(len(line))[-last_models:]:
new_meta = {"line": line.root, "model": i}
new_meta = {"line": line_root, "model": i}
try:
# TODO: to take only first is not good...
meta = view[i][0]
Expand All @@ -79,7 +82,7 @@ def _make_table(self) -> None:
metas.append(new_meta)

p = {
"line": line.root,
"line": line_root,
}
if "params" in meta:
if len(meta["params"]) > 0:
Expand Down
189 changes: 99 additions & 90 deletions cascade/meta/metric_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import pandas as pd

from . import Server, MetaViewer
from ..models import Model, ModelRepo
from ..models import Model, ModelRepo, ModelLine, SingleLineRepo


class MetricViewer:
Expand All @@ -32,7 +32,7 @@ class MetricViewer:
As metrics it uses data from `metrics` field in models'
meta and as parameters it uses `params` field.
"""
def __init__(self, repo: ModelRepo, scope: Union[int, str, slice, None] = None) -> None:
def __init__(self, repo: Union[ModelRepo, ModelLine], scope: Union[int, str, slice, None] = None) -> None:
"""
Parameters
----------
Expand All @@ -41,6 +41,8 @@ def __init__(self, repo: ModelRepo, scope: Union[int, str, slice, None] = None)
scope: Union[int, str, slice]
Index or a name of line to view. Can be set using `__getitem__`
"""
if isinstance(repo, ModelLine):
repo = SingleLineRepo(repo)
self._repo = repo
self._scope = scope
self._metrics = []
Expand All @@ -64,33 +66,29 @@ def reload_table(self) -> None:

for name in selected_names:
line = self._repo[name]
viewer_root = line.root
viewer_root = line.get_root()

view = MetaViewer(viewer_root, filt={'type': 'model'})
view = MetaViewer(viewer_root, filt={"type": "model"})

for i in range(len(line.model_names)):
try:
meta = view[i][-1] # Takes last model from meta
except IndexError:
meta = {}

metric = {
'line': viewer_root,
'num': i
}
metric = {"line": viewer_root, "num": i}

if 'created_at' in meta:
metric['created_at'] = \
pendulum.parse(meta['created_at'])
if 'saved_at' in meta:
metric['saved'] = \
pendulum.parse(meta['saved_at']) \
.diff_for_humans(metric['created_at'])
if "created_at" in meta:
metric["created_at"] = pendulum.parse(meta["created_at"])
if "saved_at" in meta:
metric["saved"] = pendulum.parse(
meta["saved_at"]
).diff_for_humans(metric["created_at"])

if 'metrics' in meta:
metric.update(meta['metrics'])
if 'params' in meta:
metric.update(meta['params'])
if "metrics" in meta:
metric.update(meta["metrics"])
if "params" in meta:
metric.update(meta["params"])

self._metrics.append(metric)
self.table = pd.DataFrame(self._metrics)
Expand All @@ -106,24 +104,30 @@ def plot_table(self, show: bool = False):
try:
import plotly
except ModuleNotFoundError:
raise ModuleNotFoundError('''
raise ModuleNotFoundError(
"""
Cannot import plotly. It is conditional
dependency you can install it
using the instructions from plotly official documentation''')
using the instructions from plotly official documentation"""
)
else:
from plotly import graph_objects as go

data = pd.DataFrame(map(flatten, self.table.to_dict('records')))
fig = go.Figure(data=[
go.Table(
header=dict(values=list(data.columns),
fill_color='#f4c9c7',
align='left'),
cells=dict(values=[data[col] for col in data.columns],
fill_color='#bcced4',
align='left')
)
])
data = pd.DataFrame(map(flatten, self.table.to_dict("records")))
fig = go.Figure(
data=[
go.Table(
header=dict(
values=list(data.columns), fill_color="#f4c9c7", align="left"
),
cells=dict(
values=[data[col] for col in data.columns],
fill_color="#bcced4",
align="left",
),
)
]
)
if show:
fig.show()
return fig
Expand All @@ -145,25 +149,26 @@ def get_best_by(self, metric: str, maximize: bool = True) -> Model:
TypeError if metric objects cannot be sorted. If only one model in repo, then
returns it without error since no sorting involved.
"""
assert metric in self.table, f'{metric} is not in {self.table.columns}'
assert metric in self.table, f"{metric} is not in {self.table.columns}"
t = self.table.loc[self.table[metric].notna()]

try:
t = t.sort_values(metric, ascending=maximize)
except TypeError as e:
raise TypeError(f'Metric {metric} objects cannot be sorted') from e
raise TypeError(f"Metric {metric} objects cannot be sorted") from e

best_row = t.iloc[-1]
name = os.path.split(best_row['line'])[-1]
num = best_row['num']
name = os.path.split(best_row["line"])[-1]
num = best_row["num"]
return self._repo[name][num]

def serve(
self,
page_size: int = 50,
include: Union[List[str], None] = None,
exclude: Union[List[str], None] = None,
**kwargs: Any) -> None:
self,
page_size: int = 50,
include: Union[List[str], None] = None,
exclude: Union[List[str], None] = None,
**kwargs: Any,
) -> None:
"""
Runs dash-based server with interactive table of metrics and parameters

Expand All @@ -179,15 +184,21 @@ def serve(
**kwargs:
Arguments of dash app. Can be ip or port for example
"""
server = MetricServer(self, page_size=page_size, include=include, exclude=exclude)
server = MetricServer(
self, page_size=page_size, include=include, exclude=exclude
)
server.serve(**kwargs)


class MetricServer(Server):
def __init__(self, mv: MetricViewer,
page_size: int,
include: Union[List[str], None],
exclude: Union[List[str], None], **kwargs: Any) -> None:
def __init__(
self,
mv: MetricViewer,
page_size: int,
include: Union[List[str], None],
exclude: Union[List[str], None],
**kwargs: Any,
) -> None:
self._mv = mv
self._page_size = page_size
self._include = include
Expand All @@ -202,20 +213,19 @@ def _update_graph_callback(self, _app) -> None:
from plotly import graph_objects as go

@_app.callback(
Output(component_id='dependence-figure', component_property='figure'),
Input(component_id='dropdown-x', component_property='value'),
Input(component_id='dropdown-y', component_property='value'))
Output(component_id="dependence-figure", component_property="figure"),
Input(component_id="dropdown-x", component_property="value"),
Input(component_id="dropdown-y", component_property="value"),
)
def _update_graph(x, y):
fig = go.Figure()
if x is not None and y is not None:
fig.add_trace(
go.Scatter(
x=self._df_flatten[x],
y=self._df_flatten[y],
mode='markers'
x=self._df_flatten[x], y=self._df_flatten[y], mode="markers"
)
)
fig.update_layout(title=f'{x} to {y} relation')
fig.update_layout(title=f"{x} to {y} relation")
return fig

def _layout(self):
Expand All @@ -234,46 +244,45 @@ def _layout(self):
df = df.drop(self._exclude, axis=1)

if self._include is not None:
df = df[['line', 'num'] + self._include]
df = df[["line", "num"] + self._include]

self._df_flatten = pd.DataFrame(map(flatten, df.to_dict('records')))
self._df_flatten = pd.DataFrame(map(flatten, df.to_dict("records")))
dep_fig = go.Figure()

return html.Div([
html.H1(
children=f'MetricViewer in {self._mv._repo}',
style={
'textAlign': 'center',
'color': '#084c61',
'font-family': 'Montserrat'
}
),
dcc.Dropdown(
list(self._df_flatten.columns),
id='dropdown-x',
multi=False),
dcc.Dropdown(
list(self._df_flatten.columns),
id='dropdown-y',
multi=False),
dcc.Graph(
id='dependence-figure',
figure=dep_fig),
dash_table.DataTable(
columns=[
{'name': col, 'id': col, 'selectable': True} for col in self._df_flatten.columns
],
data=self._df_flatten.to_dict('records'),
filter_action="native",
sort_action="native",
sort_mode="multi",
selected_columns=[],
selected_rows=[],
page_action="native",
page_current=0,
page_size=self._page_size,
)
])
return html.Div(
[
html.H1(
children=f"MetricViewer in {self._mv._repo}",
style={
"textAlign": "center",
"color": "#084c61",
"font-family": "Montserrat",
},
),
dcc.Dropdown(
list(self._df_flatten.columns), id="dropdown-x", multi=False
),
dcc.Dropdown(
list(self._df_flatten.columns), id="dropdown-y", multi=False
),
dcc.Graph(id="dependence-figure", figure=dep_fig),
dash_table.DataTable(
columns=[
{"name": col, "id": col, "selectable": True}
for col in self._df_flatten.columns
],
data=self._df_flatten.to_dict("records"),
filter_action="native",
sort_action="native",
sort_mode="multi",
selected_columns=[],
selected_rows=[],
page_action="native",
page_current=0,
page_size=self._page_size,
),
]
)

def serve(self, **kwargs: Any) -> None:
# Conditional import
Expand Down
2 changes: 1 addition & 1 deletion cascade/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@

from .model import Model, ModelModifier
from .basic_model import BasicModel, BasicModelModifier
from .model_repo import ModelRepo
from .model_repo import ModelRepo, SingleLineRepo
from .model_line import ModelLine
from .trainer import Trainer, BasicTrainer
Loading