22
22
import pandas as pd
23
23
24
24
from . import Server , MetaViewer
25
- from ..models import Model , ModelRepo
25
+ from ..models import Model , ModelRepo , ModelLine , SingleLineRepo
26
26
27
27
28
28
class MetricViewer :
@@ -32,7 +32,7 @@ class MetricViewer:
32
32
As metrics it uses data from `metrics` field in models'
33
33
meta and as parameters it uses `params` field.
34
34
"""
35
- def __init__ (self , repo : ModelRepo , scope : Union [int , str , slice , None ] = None ) -> None :
35
+ def __init__ (self , repo : Union [ ModelRepo , ModelLine ] , scope : Union [int , str , slice , None ] = None ) -> None :
36
36
"""
37
37
Parameters
38
38
----------
@@ -41,6 +41,8 @@ def __init__(self, repo: ModelRepo, scope: Union[int, str, slice, None] = None)
41
41
scope: Union[int, str, slice]
42
42
Index or a name of line to view. Can be set using `__getitem__`
43
43
"""
44
+ if isinstance (repo , ModelLine ):
45
+ repo = SingleLineRepo (repo )
44
46
self ._repo = repo
45
47
self ._scope = scope
46
48
self ._metrics = []
@@ -64,33 +66,29 @@ def reload_table(self) -> None:
64
66
65
67
for name in selected_names :
66
68
line = self ._repo [name ]
67
- viewer_root = line .root
69
+ viewer_root = line .get_root ()
68
70
69
- view = MetaViewer (viewer_root , filt = {' type' : ' model' })
71
+ view = MetaViewer (viewer_root , filt = {" type" : " model" })
70
72
71
73
for i in range (len (line .model_names )):
72
74
try :
73
75
meta = view [i ][- 1 ] # Takes last model from meta
74
76
except IndexError :
75
77
meta = {}
76
78
77
- metric = {
78
- 'line' : viewer_root ,
79
- 'num' : i
80
- }
79
+ metric = {"line" : viewer_root , "num" : i }
81
80
82
- if 'created_at' in meta :
83
- metric ['created_at' ] = \
84
- pendulum .parse (meta ['created_at' ])
85
- if 'saved_at' in meta :
86
- metric ['saved' ] = \
87
- pendulum .parse (meta ['saved_at' ]) \
88
- .diff_for_humans (metric ['created_at' ])
81
+ if "created_at" in meta :
82
+ metric ["created_at" ] = pendulum .parse (meta ["created_at" ])
83
+ if "saved_at" in meta :
84
+ metric ["saved" ] = pendulum .parse (
85
+ meta ["saved_at" ]
86
+ ).diff_for_humans (metric ["created_at" ])
89
87
90
- if ' metrics' in meta :
91
- metric .update (meta [' metrics' ])
92
- if ' params' in meta :
93
- metric .update (meta [' params' ])
88
+ if " metrics" in meta :
89
+ metric .update (meta [" metrics" ])
90
+ if " params" in meta :
91
+ metric .update (meta [" params" ])
94
92
95
93
self ._metrics .append (metric )
96
94
self .table = pd .DataFrame (self ._metrics )
@@ -106,24 +104,30 @@ def plot_table(self, show: bool = False):
106
104
try :
107
105
import plotly
108
106
except ModuleNotFoundError :
109
- raise ModuleNotFoundError ('''
107
+ raise ModuleNotFoundError (
108
+ """
110
109
Cannot import plotly. It is conditional
111
110
dependency you can install it
112
- using the instructions from plotly official documentation''' )
111
+ using the instructions from plotly official documentation"""
112
+ )
113
113
else :
114
114
from plotly import graph_objects as go
115
115
116
- data = pd .DataFrame (map (flatten , self .table .to_dict ('records' )))
117
- fig = go .Figure (data = [
118
- go .Table (
119
- header = dict (values = list (data .columns ),
120
- fill_color = '#f4c9c7' ,
121
- align = 'left' ),
122
- cells = dict (values = [data [col ] for col in data .columns ],
123
- fill_color = '#bcced4' ,
124
- align = 'left' )
125
- )
126
- ])
116
+ data = pd .DataFrame (map (flatten , self .table .to_dict ("records" )))
117
+ fig = go .Figure (
118
+ data = [
119
+ go .Table (
120
+ header = dict (
121
+ values = list (data .columns ), fill_color = "#f4c9c7" , align = "left"
122
+ ),
123
+ cells = dict (
124
+ values = [data [col ] for col in data .columns ],
125
+ fill_color = "#bcced4" ,
126
+ align = "left" ,
127
+ ),
128
+ )
129
+ ]
130
+ )
127
131
if show :
128
132
fig .show ()
129
133
return fig
@@ -145,25 +149,26 @@ def get_best_by(self, metric: str, maximize: bool = True) -> Model:
145
149
TypeError if metric objects cannot be sorted. If only one model in repo, then
146
150
returns it without error since no sorting involved.
147
151
"""
148
- assert metric in self .table , f' { metric } is not in { self .table .columns } '
152
+ assert metric in self .table , f" { metric } is not in { self .table .columns } "
149
153
t = self .table .loc [self .table [metric ].notna ()]
150
154
151
155
try :
152
156
t = t .sort_values (metric , ascending = maximize )
153
157
except TypeError as e :
154
- raise TypeError (f' Metric { metric } objects cannot be sorted' ) from e
158
+ raise TypeError (f" Metric { metric } objects cannot be sorted" ) from e
155
159
156
160
best_row = t .iloc [- 1 ]
157
- name = os .path .split (best_row [' line' ])[- 1 ]
158
- num = best_row [' num' ]
161
+ name = os .path .split (best_row [" line" ])[- 1 ]
162
+ num = best_row [" num" ]
159
163
return self ._repo [name ][num ]
160
164
161
165
def serve (
162
- self ,
163
- page_size : int = 50 ,
164
- include : Union [List [str ], None ] = None ,
165
- exclude : Union [List [str ], None ] = None ,
166
- ** kwargs : Any ) -> None :
166
+ self ,
167
+ page_size : int = 50 ,
168
+ include : Union [List [str ], None ] = None ,
169
+ exclude : Union [List [str ], None ] = None ,
170
+ ** kwargs : Any ,
171
+ ) -> None :
167
172
"""
168
173
Runs dash-based server with interactive table of metrics and parameters
169
174
@@ -179,15 +184,21 @@ def serve(
179
184
**kwargs:
180
185
Arguments of dash app. Can be ip or port for example
181
186
"""
182
- server = MetricServer (self , page_size = page_size , include = include , exclude = exclude )
187
+ server = MetricServer (
188
+ self , page_size = page_size , include = include , exclude = exclude
189
+ )
183
190
server .serve (** kwargs )
184
191
185
192
186
193
class MetricServer (Server ):
187
- def __init__ (self , mv : MetricViewer ,
188
- page_size : int ,
189
- include : Union [List [str ], None ],
190
- exclude : Union [List [str ], None ], ** kwargs : Any ) -> None :
194
+ def __init__ (
195
+ self ,
196
+ mv : MetricViewer ,
197
+ page_size : int ,
198
+ include : Union [List [str ], None ],
199
+ exclude : Union [List [str ], None ],
200
+ ** kwargs : Any ,
201
+ ) -> None :
191
202
self ._mv = mv
192
203
self ._page_size = page_size
193
204
self ._include = include
@@ -202,20 +213,19 @@ def _update_graph_callback(self, _app) -> None:
202
213
from plotly import graph_objects as go
203
214
204
215
@_app .callback (
205
- Output (component_id = 'dependence-figure' , component_property = 'figure' ),
206
- Input (component_id = 'dropdown-x' , component_property = 'value' ),
207
- Input (component_id = 'dropdown-y' , component_property = 'value' ))
216
+ Output (component_id = "dependence-figure" , component_property = "figure" ),
217
+ Input (component_id = "dropdown-x" , component_property = "value" ),
218
+ Input (component_id = "dropdown-y" , component_property = "value" ),
219
+ )
208
220
def _update_graph (x , y ):
209
221
fig = go .Figure ()
210
222
if x is not None and y is not None :
211
223
fig .add_trace (
212
224
go .Scatter (
213
- x = self ._df_flatten [x ],
214
- y = self ._df_flatten [y ],
215
- mode = 'markers'
225
+ x = self ._df_flatten [x ], y = self ._df_flatten [y ], mode = "markers"
216
226
)
217
227
)
218
- fig .update_layout (title = f' { x } to { y } relation' )
228
+ fig .update_layout (title = f" { x } to { y } relation" )
219
229
return fig
220
230
221
231
def _layout (self ):
@@ -234,46 +244,45 @@ def _layout(self):
234
244
df = df .drop (self ._exclude , axis = 1 )
235
245
236
246
if self ._include is not None :
237
- df = df [[' line' , ' num' ] + self ._include ]
247
+ df = df [[" line" , " num" ] + self ._include ]
238
248
239
- self ._df_flatten = pd .DataFrame (map (flatten , df .to_dict (' records' )))
249
+ self ._df_flatten = pd .DataFrame (map (flatten , df .to_dict (" records" )))
240
250
dep_fig = go .Figure ()
241
251
242
- return html .Div ([
243
- html .H1 (
244
- children = f'MetricViewer in { self ._mv ._repo } ' ,
245
- style = {
246
- 'textAlign' : 'center' ,
247
- 'color' : '#084c61' ,
248
- 'font-family' : 'Montserrat'
249
- }
250
- ),
251
- dcc .Dropdown (
252
- list (self ._df_flatten .columns ),
253
- id = 'dropdown-x' ,
254
- multi = False ),
255
- dcc .Dropdown (
256
- list (self ._df_flatten .columns ),
257
- id = 'dropdown-y' ,
258
- multi = False ),
259
- dcc .Graph (
260
- id = 'dependence-figure' ,
261
- figure = dep_fig ),
262
- dash_table .DataTable (
263
- columns = [
264
- {'name' : col , 'id' : col , 'selectable' : True } for col in self ._df_flatten .columns
265
- ],
266
- data = self ._df_flatten .to_dict ('records' ),
267
- filter_action = "native" ,
268
- sort_action = "native" ,
269
- sort_mode = "multi" ,
270
- selected_columns = [],
271
- selected_rows = [],
272
- page_action = "native" ,
273
- page_current = 0 ,
274
- page_size = self ._page_size ,
275
- )
276
- ])
252
+ return html .Div (
253
+ [
254
+ html .H1 (
255
+ children = f"MetricViewer in { self ._mv ._repo } " ,
256
+ style = {
257
+ "textAlign" : "center" ,
258
+ "color" : "#084c61" ,
259
+ "font-family" : "Montserrat" ,
260
+ },
261
+ ),
262
+ dcc .Dropdown (
263
+ list (self ._df_flatten .columns ), id = "dropdown-x" , multi = False
264
+ ),
265
+ dcc .Dropdown (
266
+ list (self ._df_flatten .columns ), id = "dropdown-y" , multi = False
267
+ ),
268
+ dcc .Graph (id = "dependence-figure" , figure = dep_fig ),
269
+ dash_table .DataTable (
270
+ columns = [
271
+ {"name" : col , "id" : col , "selectable" : True }
272
+ for col in self ._df_flatten .columns
273
+ ],
274
+ data = self ._df_flatten .to_dict ("records" ),
275
+ filter_action = "native" ,
276
+ sort_action = "native" ,
277
+ sort_mode = "multi" ,
278
+ selected_columns = [],
279
+ selected_rows = [],
280
+ page_action = "native" ,
281
+ page_current = 0 ,
282
+ page_size = self ._page_size ,
283
+ ),
284
+ ]
285
+ )
277
286
278
287
def serve (self , ** kwargs : Any ) -> None :
279
288
# Conditional import
0 commit comments