14
14
15
15
from googleapiclient import discovery
16
16
import os
17
- import time
18
17
import yaml
19
18
20
19
import datalab .context
21
20
import datalab .storage
22
21
import datalab .utils
23
22
23
+ from . import _util
24
24
25
- # TODO(qimingj) Remove once the API is public since it will no longer be needed
26
- _CLOUDML_DISCOVERY_URL = 'https://storage.googleapis.com/cloud-ml/discovery/' \
27
- 'ml_v1beta1_discovery.json'
28
-
29
-
30
- class CloudModels (object ):
25
+ class Models (object ):
31
26
"""Represents a list of Cloud ML models for a project."""
32
27
33
28
def __init__ (self , project_id = None ):
34
- """Initializes an instance of a CloudML Model list that is iteratable
35
- ("for model in CloudModels()").
36
-
29
+ """
37
30
Args:
38
31
project_id: project_id of the models. If not provided, default project_id will be used.
39
32
"""
40
33
if project_id is None :
41
34
project_id = datalab .context .Context .default ().project_id
42
35
self ._project_id = project_id
43
36
self ._credentials = datalab .context .Context .default ().credentials
44
- self ._api = discovery .build ('ml' , 'v1alpha3' , credentials = self ._credentials ,
45
- discoveryServiceUrl = _CLOUDML_DISCOVERY_URL )
37
+ self ._api = discovery .build ('ml' , 'v1' , credentials = self ._credentials )
46
38
47
39
def _retrieve_models (self , page_token , page_size ):
48
40
list_info = self ._api .projects ().models ().list (
@@ -51,11 +43,13 @@ def _retrieve_models(self, page_token, page_size):
51
43
page_token = list_info .get ('nextPageToken' , None )
52
44
return models , page_token
53
45
54
- def __iter__ (self ):
46
+ def get_iterator (self ):
47
+ """Get iterator of models so it can be used as "for model in Models().get_iterator()".
48
+ """
55
49
return iter (datalab .utils .Iterator (self ._retrieve_models ))
56
50
57
51
def get_model_details (self , model_name ):
58
- """Get details of a model.
52
+ """Get details of the specified model from CloudML Service.
59
53
60
54
Args:
61
55
model_name: the name of the model. It can be a model full name
@@ -72,10 +66,16 @@ def create(self, model_name):
72
66
73
67
Args:
74
68
model_name: the short name of the model, such as "iris".
69
+ Returns:
70
+ If successful, returns information about the model, such as
71
+ {u'regions': [u'us-central1'], u'name': u'projects/myproject/models/mymodel'}
72
+ Raises:
73
+ An exception if the model creation failed.
75
74
"""
76
75
body = {'name' : model_name }
77
76
parent = 'projects/' + self ._project_id
78
- self ._api .projects ().models ().create (body = body , parent = parent ).execute ()
77
+ # Model creation is instant. If anything goes wrong, Exception will be thrown.
78
+ return self ._api .projects ().models ().create (body = body , parent = parent ).execute ()
79
79
80
80
def delete (self , model_name ):
81
81
"""Delete a model.
@@ -87,7 +87,10 @@ def delete(self, model_name):
87
87
full_name = model_name
88
88
if not model_name .startswith ('projects/' ):
89
89
full_name = ('projects/%s/models/%s' % (self ._project_id , model_name ))
90
- return self ._api .projects ().models ().delete (name = full_name ).execute ()
90
+ response = self ._api .projects ().models ().delete (name = full_name ).execute ()
91
+ if 'name' not in response :
92
+ raise Exception ('Invalid response from service. "name" is not found.' )
93
+ _util .wait_for_long_running_operation (response ['name' ])
91
94
92
95
def list (self , count = 10 ):
93
96
"""List models under the current project in a table view.
@@ -121,13 +124,11 @@ def describe(self, model_name):
121
124
print model_yaml
122
125
123
126
124
- class CloudModelVersions (object ):
127
+ class ModelVersions (object ):
125
128
"""Represents a list of versions for a Cloud ML model."""
126
129
127
130
def __init__ (self , model_name , project_id = None ):
128
- """Initializes an instance of a CloudML model version list that is iteratable
129
- ("for version in CloudModelVersions()").
130
-
131
+ """
131
132
Args:
132
133
model_name: the name of the model. It can be a model full name
133
134
("projects/[project_id]/models/[model_name]") or just [model_name].
@@ -137,8 +138,7 @@ def __init__(self, model_name, project_id=None):
137
138
if project_id is None :
138
139
self ._project_id = datalab .context .Context .default ().project_id
139
140
self ._credentials = datalab .context .Context .default ().credentials
140
- self ._api = discovery .build ('ml' , 'v1alpha3' , credentials = self ._credentials ,
141
- discoveryServiceUrl = _CLOUDML_DISCOVERY_URL )
141
+ self ._api = discovery .build ('ml' , 'v1' , credentials = self ._credentials )
142
142
if not model_name .startswith ('projects/' ):
143
143
model_name = ('projects/%s/models/%s' % (self ._project_id , model_name ))
144
144
self ._full_model_name = model_name
@@ -152,7 +152,10 @@ def _retrieve_versions(self, page_token, page_size):
152
152
page_token = list_info .get ('nextPageToken' , None )
153
153
return versions , page_token
154
154
155
- def __iter__ (self ):
155
+ def get_iterator (self ):
156
+ """Get iterator of versions so it can be used as
157
+ "for v in ModelVersions(model_name).get_iterator()".
158
+ """
156
159
return iter (datalab .utils .Iterator (self ._retrieve_versions ))
157
160
158
161
def get_version_details (self , version_name ):
@@ -165,21 +168,6 @@ def get_version_details(self, version_name):
165
168
name = ('%s/versions/%s' % (self ._full_model_name , version_name ))
166
169
return self ._api .projects ().models ().versions ().get (name = name ).execute ()
167
170
168
- def _wait_for_long_running_operation (self , response ):
169
- if 'name' not in response :
170
- raise Exception ('Invaid response from service. Cannot find "name" field.' )
171
- print ('Waiting for job "%s"' % response ['name' ])
172
- while True :
173
- response = self ._api .projects ().operations ().get (name = response ['name' ]).execute ()
174
- if 'done' not in response or response ['done' ] != True :
175
- time .sleep (3 )
176
- else :
177
- if 'error' in response :
178
- print (response ['error' ])
179
- else :
180
- print ('Done.' )
181
- break
182
-
183
171
def deploy (self , version_name , path ):
184
172
"""Deploy a model version to the cloud.
185
173
@@ -211,7 +199,9 @@ def deploy(self, version_name, path):
211
199
}
212
200
response = self ._api .projects ().models ().versions ().create (body = body ,
213
201
parent = self ._full_model_name ).execute ()
214
- self ._wait_for_long_running_operation (response )
202
+ if 'name' not in response :
203
+ raise Exception ('Invalid response from service. "name" is not found.' )
204
+ _util .wait_for_long_running_operation (response ['name' ])
215
205
216
206
def delete (self , version_name ):
217
207
"""Delete a version of model.
@@ -221,8 +211,37 @@ def delete(self, version_name):
221
211
"""
222
212
name = ('%s/versions/%s' % (self ._full_model_name , version_name ))
223
213
response = self ._api .projects ().models ().versions ().delete (name = name ).execute ()
224
- self ._wait_for_long_running_operation (response )
225
-
214
+ if 'name' not in response :
215
+ raise Exception ('Invalid response from service. "name" is not found.' )
216
+ _util .wait_for_long_running_operation (response ['name' ])
217
+
218
+ def predict (self , version_name , data ):
219
+ """Get prediction results from features instances.
220
+
221
+ Args:
222
+ version_name: the name of the version used for prediction.
223
+ data: typically a list of instances to be submitted for prediction. The format of the
224
+ instance depends on the model. For example, structured data model may require
225
+ a csv line for each instance.
226
+ Note that online prediction only works on models that take one placeholder value,
227
+ such as a string encoding a csv line.
228
+ Returns:
229
+ A list of prediction results for given instances. Each element is a dictionary representing
230
+ output mapping from the graph.
231
+ An example:
232
+ [{"predictions": 1, "score": [0.00078, 0.71406, 0.28515]},
233
+ {"predictions": 1, "score": [0.00244, 0.99634, 0.00121]}]
234
+ """
235
+ full_version_name = ('%s/versions/%s' % (self ._full_model_name , version_name ))
236
+ request = self ._api .projects ().predict (body = {'instances' : data },
237
+ name = full_version_name )
238
+ request .headers ['user-agent' ] = 'GoogleCloudDataLab/1.0'
239
+ result = request .execute ()
240
+ if 'predictions' not in result :
241
+ raise Exception ('Invalid response from service. Cannot find "predictions" in response.' )
242
+
243
+ return result ['predictions' ]
244
+
226
245
def describe (self , version_name ):
227
246
"""Print information of a specified model.
228
247
0 commit comments