Skip to content
This repository was archived by the owner on Sep 3, 2022. It is now read-only.

Commit 6b61f15

Browse files
committed
Move job, models, and feature_slice_view plotting to API. (#167)
* Move job, models, and feature_slice_view plotting to API. * Follow up on CR comments.
1 parent f404e5f commit 6b61f15

File tree

5 files changed

+186
-46
lines changed

5 files changed

+186
-46
lines changed

datalab/mlalpha/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ._metadata import Metadata
2020
from ._local_predictor import LocalPredictor
2121
from ._cloud_predictor import CloudPredictor
22-
from ._job import Jobs
22+
from ._job import Jobs, Job
2323
from ._summary import Summary
2424
from ._tensorboard import TensorBoard
2525
from ._dataset import CsvDataSet, BigQueryDataSet
@@ -28,6 +28,7 @@
2828
from ._confusion_matrix import ConfusionMatrix
2929
from ._analysis import csv_to_dataframe
3030
from ._package_runner import PackageRunner
31+
from ._feature_slice_view import FeatureSliceView
3132

3233
from plotly.offline import init_notebook_mode
3334

datalab/mlalpha/_cloud_models.py

Lines changed: 70 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from googleapiclient import discovery
1616
import os
1717
import time
18+
import yaml
1819

1920
import datalab.context
2021
import datalab.storage
@@ -29,38 +30,31 @@
2930
class CloudModels(object):
3031
"""Represents a list of Cloud ML models for a project."""
3132

32-
def __init__(self, project_id=None, credentials=None, api=None):
33+
def __init__(self, project_id=None):
3334
"""Initializes an instance of a CloudML Model list that is iteratable
3435
("for model in CloudModels()").
3536
3637
Args:
3738
project_id: project_id of the models. If not provided, default project_id will be used.
38-
credentials: credentials used to talk to CloudML service. If not provided, default credentials
39-
will be used.
40-
api: an optional CloudML API client.
4139
"""
4240
if project_id is None:
4341
project_id = datalab.context.Context.default().project_id
4442
self._project_id = project_id
45-
if credentials is None:
46-
credentials = datalab.context.Context.default().credentials
47-
self._credentials = credentials
48-
if api is None:
49-
api = discovery.build('ml', 'v1beta1', credentials=self._credentials,
50-
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL)
51-
self._api = api
43+
self._credentials = datalab.context.Context.default().credentials
44+
self._api = discovery.build('ml', 'v1alpha3', credentials=self._credentials,
45+
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL)
5246

5347
def _retrieve_models(self, page_token, page_size):
54-
list_info = self._api.projects().models().list(parent='projects/' + self._project_id,
55-
pageToken=page_token, pageSize=page_size).execute()
48+
list_info = self._api.projects().models().list(
49+
parent='projects/' + self._project_id, pageToken=page_token, pageSize=page_size).execute()
5650
models = list_info.get('models', [])
5751
page_token = list_info.get('nextPageToken', None)
5852
return models, page_token
5953

6054
def __iter__(self):
6155
return iter(datalab.utils.Iterator(self._retrieve_models))
6256

63-
def get(self, model_name):
57+
def get_model_details(self, model_name):
6458
"""Get details of a model.
6559
6660
Args:
@@ -95,11 +89,42 @@ def delete(self, model_name):
9589
full_name = ('projects/%s/models/%s' % (self._project_id, model_name))
9690
return self._api.projects().models().delete(name=full_name).execute()
9791

92+
def list(self, count=10):
93+
"""List models under the current project in a table view.
94+
95+
Args:
96+
count: upper limit of the number of models to list.
97+
Raises:
98+
Exception if it is called in a non-IPython environment.
99+
"""
100+
import IPython
101+
data = []
102+
# Add range(count) to loop so it will stop either it reaches count, or iteration
103+
# on self is exhausted. "self" is iterable (see __iter__() method).
104+
for _, model in zip(range(count), self):
105+
element = {'name': model['name']}
106+
if 'defaultVersion' in model:
107+
version_short_name = model['defaultVersion']['name'].split('/')[-1]
108+
element['defaultVersion'] = version_short_name
109+
data.append(element)
110+
111+
IPython.display.display(
112+
datalab.utils.commands.render_dictionary(data, ['name', 'defaultVersion']))
113+
114+
def describe(self, model_name):
115+
"""Print information of a specified model.
116+
117+
Args:
118+
model_name: the name of the model to print details on.
119+
"""
120+
model_yaml = yaml.safe_dump(self.get_model_details(model_name), default_flow_style=False)
121+
print model_yaml
122+
98123

99124
class CloudModelVersions(object):
100125
"""Represents a list of versions for a Cloud ML model."""
101126

102-
def __init__(self, model_name, project_id=None, credentials=None, api=None):
127+
def __init__(self, model_name, project_id=None):
103128
"""Initializes an instance of a CloudML model version list that is iteratable
104129
("for version in CloudModelVersions()").
105130
@@ -108,20 +133,12 @@ def __init__(self, model_name, project_id=None, credentials=None, api=None):
108133
("projects/[project_id]/models/[model_name]") or just [model_name].
109134
project_id: project_id of the models. If not provided and model_name is not a full name
110135
(not including project_id), default project_id will be used.
111-
credentials: credentials used to talk to CloudML service. If not provided, default
112-
credentials will be used.
113-
api: an optional CloudML API client.
114136
"""
115137
if project_id is None:
116-
project_id = datalab.context.Context.default().project_id
117-
self._project_id = project_id
118-
if credentials is None:
119-
credentials = datalab.context.Context.default().credentials
120-
self._credentials = credentials
121-
if api is None:
122-
api = discovery.build('ml', 'v1alpha3', credentials=self._credentials,
123-
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL)
124-
self._api = api
138+
self._project_id = datalab.context.Context.default().project_id
139+
self._credentials = datalab.context.Context.default().credentials
140+
self._api = discovery.build('ml', 'v1alpha3', credentials=self._credentials,
141+
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL)
125142
if not model_name.startswith('projects/'):
126143
model_name = ('projects/%s/models/%s' % (self._project_id, model_name))
127144
self._full_model_name = model_name
@@ -138,7 +155,7 @@ def _retrieve_versions(self, page_token, page_size):
138155
def __iter__(self):
139156
return iter(datalab.utils.Iterator(self._retrieve_versions))
140157

141-
def get(self, version_name):
158+
def get_version_details(self, version_name):
142159
"""Get details of a version.
143160
144161
Args:
@@ -205,3 +222,28 @@ def delete(self, version_name):
205222
name = ('%s/versions/%s' % (self._full_model_name, version_name))
206223
response = self._api.projects().models().versions().delete(name=name).execute()
207224
self._wait_for_long_running_operation(response)
225+
226+
def describe(self, version_name):
227+
"""Print information of a specified model.
228+
229+
Args:
230+
version: the name of the version in short form, such as "v1".
231+
"""
232+
version_yaml = yaml.safe_dump(self.get_version_details(version_name),
233+
default_flow_style=False)
234+
print version_yaml
235+
236+
def list(self):
237+
"""List versions under the current model in a table view.
238+
239+
Raises:
240+
Exception if it is called in a non-IPython environment.
241+
"""
242+
import IPython
243+
244+
# "self" is iterable (see __iter__() method).
245+
data = [{'name': version['name'].split()[-1],
246+
'deploymentUri': version['deploymentUri'], 'createTime': version['createTime']}
247+
for version in self]
248+
IPython.display.display(
249+
datalab.utils.commands.render_dictionary(data, ['name', 'deploymentUri', 'createTime']))
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright 2017 Google Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4+
# in compliance with the License. You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software distributed under the License
9+
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10+
# or implied. See the License for the specific language governing permissions and limitations under
11+
# the License.
12+
13+
import json
14+
import pandas as pd
15+
from types import ModuleType
16+
17+
import datalab.data
18+
import datalab.utils
19+
20+
21+
class FeatureSliceView(object):
22+
"""Represents A feature slice view."""
23+
24+
def _get_lantern_format(self, df):
25+
""" Feature slice view browser expects data in the format of:
26+
{"metricValues": {"count": 12, "accuracy": 1.0}, "feature": "species:Iris-setosa"}
27+
{"metricValues": {"count": 11, "accuracy": 0.72}, "feature": "species:Iris-versicolor"}
28+
...
29+
This function converts a DataFrame to such format.
30+
"""
31+
32+
if ('count' not in df) or ('feature' not in df):
33+
raise Exception('No "count" or "feature" found in data.')
34+
if len(df.columns) < 3:
35+
raise Exception('Need at least one metrics column.')
36+
if len(df) == 0:
37+
raise Exception('Data is empty')
38+
39+
data = []
40+
for _, row in df.iterrows():
41+
metric_values = dict(row)
42+
feature = metric_values.pop('feature')
43+
data.append({'feature': feature, 'metricValues': metric_values})
44+
return data
45+
46+
def plot(self, data):
47+
""" Plots a featire slice view on given data.
48+
49+
Args:
50+
data: Can be one of:
51+
A string of sql query.
52+
A sql query module defined by "%%sql --module module_name".
53+
A pandas DataFrame.
54+
Regardless of data type, it must include the following columns:
55+
"feature": identifies a slice of features. For example: "petal_length:4.0-4.2".
56+
"count": number of instances in that slice of features.
57+
All other columns are viewed as metrics for its feature slice. At least one is required.
58+
"""
59+
import IPython
60+
61+
if isinstance(data, ModuleType) or isinstance(data, basestring):
62+
item, _ = datalab.data.SqlModule.get_sql_statement_with_environment(data, {})
63+
query = datalab.bigquery.Query(item)
64+
df = query.results().to_dataframe()
65+
data = self._get_lantern_format(df)
66+
elif isinstance(data, pd.core.frame.DataFrame):
67+
data = self._get_lantern_format(data)
68+
else:
69+
raise Exception('data needs to be a sql query, or a pandas DataFrame.')
70+
71+
HTML_TEMPLATE = """<link rel="import" href="/nbextensions/gcpdatalab/extern/lantern-browser.html" >
72+
<lantern-browser id="{html_id}"></lantern-browser>
73+
<script>
74+
var browser = document.querySelector('#{html_id}');
75+
browser.metrics = {metrics};
76+
browser.data = {data};
77+
browser.sourceType = 'colab';
78+
browser.weightedExamplesColumn = 'count';
79+
browser.calibrationPlotUriFn = function(s) {{ return '/' + s; }}
80+
</script>"""
81+
# Serialize the data and list of metrics names to JSON string.
82+
metrics_str = str(map(str, data[0]['metricValues'].keys()))
83+
data_str = str([{str(k): json.dumps(v) for k,v in elem.iteritems()} for elem in data])
84+
html_id = 'l' + datalab.utils.commands.Html.next_id()
85+
html = HTML_TEMPLATE.format(html_id=html_id, metrics=metrics_str, data=data_str)
86+
IPython.display.display(IPython.display.HTML(html))
87+

datalab/mlalpha/_job.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,20 @@
1212

1313
"""Implements Cloud ML Operation wrapper."""
1414

15+
1516
import datalab.utils
1617
import datalab.context
1718
from googleapiclient import discovery
19+
import yaml
20+
21+
# TODO(qimingj) Remove once the API is public since it will no longer be needed
22+
_CLOUDML_DISCOVERY_URL = 'https://storage.googleapis.com/cloud-ml/discovery/' \
23+
'ml_v1beta1_discovery.json'
1824

25+
import datalab.utils
26+
import datalab.context
27+
from googleapiclient import discovery
28+
import yaml
1929

2030
# TODO(qimingj) Remove once the API is public since it will no longer be needed
2131
_CLOUDML_DISCOVERY_URL = 'https://storage.googleapis.com/cloud-ml/discovery/' \
@@ -54,26 +64,26 @@ def refresh(self):
5464
""" Refresh the job info. """
5565
self._info = self._api.projects().jobs().get(name=self._name).execute()
5666

67+
def describe(self):
68+
job_yaml = yaml.safe_dump(self._info, default_flow_style=False)
69+
print job_yaml
70+
5771

5872
class Jobs(object):
5973
"""Represents a list of Cloud ML jobs for a project."""
6074

61-
def __init__(self, filter=None, context=None, api=None):
75+
def __init__(self, filter=None):
6276
"""Initializes an instance of a CloudML Job list that is iteratable ("for job in jobs()").
6377
6478
Args:
65-
filter: filter string for retrieving jobs. Currently only "done=true|false" is supported.
79+
filter: filter string for retrieving jobs, such as "state=FAILED"
6680
context: an optional Context object providing project_id and credentials.
6781
api: an optional CloudML API client.
6882
"""
6983
self._filter = filter
70-
if context is None:
71-
context = datalab.context.Context.default()
72-
self._context = context
73-
if api is None:
74-
api = discovery.build('ml', 'v1beta1', credentials=self._context.credentials,
75-
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL)
76-
self._api = api
84+
self._context = datalab.context.Context.default()
85+
self._api = discovery.build('ml', 'v1beta1', credentials=self._context.credentials,
86+
discoveryServiceUrl=_CLOUDML_DISCOVERY_URL)
7787

7888
def _retrieve_jobs(self, page_token, page_size):
7989
list_info = self._api.projects().jobs().list(parent='projects/' + self._context.project_id,
@@ -86,10 +96,10 @@ def _retrieve_jobs(self, page_token, page_size):
8696
def __iter__(self):
8797
return iter(datalab.utils.Iterator(self._retrieve_jobs))
8898

89-
def get_job_by_name(self, name):
90-
""" get a CloudML job by its name.
91-
Args:
92-
name: the name of the job. See "Job" class constructor.
93-
"""
94-
return Job(name, self._context, self._api)
95-
99+
def list(self, count=10):
100+
import IPython
101+
data = [{'Id': job['jobId'], 'State': job.get('state', 'UNKNOWN'),
102+
'createTime': job['createTime']}
103+
for _, job in zip(range(count), self)]
104+
IPython.display.display(
105+
datalab.utils.commands.render_dictionary(data, ['Id', 'State', 'createTime']))

datalab/mlalpha/commands/_mlalpha.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ def _model(args, _):
618618
return
619619
elif len(parts) == 2:
620620
versions = datalab.mlalpha.CloudModelVersions(parts[0], project_id=args['project'])
621-
version_yaml = yaml.safe_dump(versions.get(parts[1]))
621+
version_yaml = yaml.safe_dump(versions.get_version_details(parts[1]))
622622
return datalab.utils.commands.render_text(version_yaml, preformatted=True)
623623
else:
624624
raise Exception('Too many "." in name. Use "model" or "model.version".')

0 commit comments

Comments
 (0)