15
15
from googleapiclient import discovery
16
16
import os
17
17
import time
18
+ import yaml
18
19
19
20
import datalab .context
20
21
import datalab .storage
29
30
class CloudModels (object ):
30
31
"""Represents a list of Cloud ML models for a project."""
31
32
32
- def __init__ (self , project_id = None , credentials = None , api = None ):
33
+ def __init__ (self , project_id = None ):
33
34
"""Initializes an instance of a CloudML Model list that is iteratable
34
35
("for model in CloudModels()").
35
36
36
37
Args:
37
38
project_id: project_id of the models. If not provided, default project_id will be used.
38
- credentials: credentials used to talk to CloudML service. If not provided, default credentials
39
- will be used.
40
- api: an optional CloudML API client.
41
39
"""
42
40
if project_id is None :
43
41
project_id = datalab .context .Context .default ().project_id
44
42
self ._project_id = project_id
45
- if credentials is None :
46
- credentials = datalab .context .Context .default ().credentials
47
- self ._credentials = credentials
48
- if api is None :
49
- api = discovery .build ('ml' , 'v1beta1' , credentials = self ._credentials ,
50
- discoveryServiceUrl = _CLOUDML_DISCOVERY_URL )
51
- self ._api = api
43
+ self ._credentials = datalab .context .Context .default ().credentials
44
+ self ._api = discovery .build ('ml' , 'v1alpha3' , credentials = self ._credentials ,
45
+ discoveryServiceUrl = _CLOUDML_DISCOVERY_URL )
52
46
53
47
def _retrieve_models (self , page_token , page_size ):
54
- list_info = self ._api .projects ().models ().list (parent = 'projects/' + self . _project_id ,
55
- pageToken = page_token , pageSize = page_size ).execute ()
48
+ list_info = self ._api .projects ().models ().list (
49
+ parent = 'projects/' + self . _project_id , pageToken = page_token , pageSize = page_size ).execute ()
56
50
models = list_info .get ('models' , [])
57
51
page_token = list_info .get ('nextPageToken' , None )
58
52
return models , page_token
59
53
60
54
def __iter__ (self ):
61
55
return iter (datalab .utils .Iterator (self ._retrieve_models ))
62
56
63
- def get (self , model_name ):
57
+ def get_model_details (self , model_name ):
64
58
"""Get details of a model.
65
59
66
60
Args:
@@ -95,11 +89,42 @@ def delete(self, model_name):
95
89
full_name = ('projects/%s/models/%s' % (self ._project_id , model_name ))
96
90
return self ._api .projects ().models ().delete (name = full_name ).execute ()
97
91
92
+ def list (self , count = 10 ):
93
+ """List models under the current project in a table view.
94
+
95
+ Args:
96
+ count: upper limit of the number of models to list.
97
+ Raises:
98
+ Exception if it is called in a non-IPython environment.
99
+ """
100
+ import IPython
101
+ data = []
102
+ # Add range(count) to loop so it will stop either it reaches count, or iteration
103
+ # on self is exhausted. "self" is iterable (see __iter__() method).
104
+ for _ , model in zip (range (count ), self ):
105
+ element = {'name' : model ['name' ]}
106
+ if 'defaultVersion' in model :
107
+ version_short_name = model ['defaultVersion' ]['name' ].split ('/' )[- 1 ]
108
+ element ['defaultVersion' ] = version_short_name
109
+ data .append (element )
110
+
111
+ IPython .display .display (
112
+ datalab .utils .commands .render_dictionary (data , ['name' , 'defaultVersion' ]))
113
+
114
+ def describe (self , model_name ):
115
+ """Print information of a specified model.
116
+
117
+ Args:
118
+ model_name: the name of the model to print details on.
119
+ """
120
+ model_yaml = yaml .safe_dump (self .get_model_details (model_name ), default_flow_style = False )
121
+ print model_yaml
122
+
98
123
99
124
class CloudModelVersions (object ):
100
125
"""Represents a list of versions for a Cloud ML model."""
101
126
102
- def __init__ (self , model_name , project_id = None , credentials = None , api = None ):
127
+ def __init__ (self , model_name , project_id = None ):
103
128
"""Initializes an instance of a CloudML model version list that is iteratable
104
129
("for version in CloudModelVersions()").
105
130
@@ -108,20 +133,12 @@ def __init__(self, model_name, project_id=None, credentials=None, api=None):
108
133
("projects/[project_id]/models/[model_name]") or just [model_name].
109
134
project_id: project_id of the models. If not provided and model_name is not a full name
110
135
(not including project_id), default project_id will be used.
111
- credentials: credentials used to talk to CloudML service. If not provided, default
112
- credentials will be used.
113
- api: an optional CloudML API client.
114
136
"""
115
137
if project_id is None :
116
- project_id = datalab .context .Context .default ().project_id
117
- self ._project_id = project_id
118
- if credentials is None :
119
- credentials = datalab .context .Context .default ().credentials
120
- self ._credentials = credentials
121
- if api is None :
122
- api = discovery .build ('ml' , 'v1alpha3' , credentials = self ._credentials ,
123
- discoveryServiceUrl = _CLOUDML_DISCOVERY_URL )
124
- self ._api = api
138
+ self ._project_id = datalab .context .Context .default ().project_id
139
+ self ._credentials = datalab .context .Context .default ().credentials
140
+ self ._api = discovery .build ('ml' , 'v1alpha3' , credentials = self ._credentials ,
141
+ discoveryServiceUrl = _CLOUDML_DISCOVERY_URL )
125
142
if not model_name .startswith ('projects/' ):
126
143
model_name = ('projects/%s/models/%s' % (self ._project_id , model_name ))
127
144
self ._full_model_name = model_name
@@ -138,7 +155,7 @@ def _retrieve_versions(self, page_token, page_size):
138
155
def __iter__ (self ):
139
156
return iter (datalab .utils .Iterator (self ._retrieve_versions ))
140
157
141
- def get (self , version_name ):
158
+ def get_version_details (self , version_name ):
142
159
"""Get details of a version.
143
160
144
161
Args:
@@ -205,3 +222,28 @@ def delete(self, version_name):
205
222
name = ('%s/versions/%s' % (self ._full_model_name , version_name ))
206
223
response = self ._api .projects ().models ().versions ().delete (name = name ).execute ()
207
224
self ._wait_for_long_running_operation (response )
225
+
226
+ def describe (self , version_name ):
227
+ """Print information of a specified model.
228
+
229
+ Args:
230
+ version: the name of the version in short form, such as "v1".
231
+ """
232
+ version_yaml = yaml .safe_dump (self .get_version_details (version_name ),
233
+ default_flow_style = False )
234
+ print version_yaml
235
+
236
+ def list (self ):
237
+ """List versions under the current model in a table view.
238
+
239
+ Raises:
240
+ Exception if it is called in a non-IPython environment.
241
+ """
242
+ import IPython
243
+
244
+ # "self" is iterable (see __iter__() method).
245
+ data = [{'name' : version ['name' ].split ()[- 1 ],
246
+ 'deploymentUri' : version ['deploymentUri' ], 'createTime' : version ['createTime' ]}
247
+ for version in self ]
248
+ IPython .display .display (
249
+ datalab .utils .commands .render_dictionary (data , ['name' , 'deploymentUri' , 'createTime' ]))
0 commit comments