 import tempfile
 import urllib
 import json
+import glob
+import StringIO

+import pandas as pd
 import tensorflow as tf
 import yaml

 from . import preprocess
 from . import trainer
+from . import predict

 _TF_GS_URL = 'gs://cloud-datalab/deploy/tf/tensorflow-0.12.0rc0-cp27-none-linux_x86_64.whl'

@@ -112,7 +116,8 @@ def local_preprocess(output_dir, input_feature_file, input_file_pattern, schema_
 def cloud_preprocess(output_dir, input_feature_file, input_file_pattern=None, schema_file=None, bigquery_table=None, project_id=None):
   """Preprocess data in the cloud with BigQuery.

-  Produce analysis used by training.
+  Produce analysis used by training. This can take a while, even for small
+  datasets. For small datasets, it may be faster to use local_preprocess.

   Args:
     output_dir: The output directory to use.
@@ -133,13 +138,15 @@ def cloud_preprocess(output_dir, input_feature_file, input_file_pattern=None, sc
     args.append('--input_file_pattern=%s' % input_file_pattern)
   if schema_file:
     args.append('--schema_file=%s' % schema_file)
+  if not project_id:
+    project_id = _default_project()
   if bigquery_table:
-    if not project_id:
-      project_id = _default_project()
     full_name = project_id + ':' + bigquery_table
     args.append('--bigquery_table=%s' % full_name)

   print('Starting cloud preprocessing.')
+  print('Track BigQuery status at')
+  print('https://bigquery.cloud.google.com/queries/%s' % project_id)
   preprocess.cloud_preprocess.main(args)
   print('Cloud preprocessing done.')

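
A minimal usage sketch of cloud_preprocess with a BigQuery source; the output path, feature file, table, and project are hypothetical placeholders, and project_id can be omitted to fall back to _default_project():

    cloud_preprocess(output_dir='gs://my-bucket/preprocess',      # hypothetical GCS path
                     input_feature_file='./input_features.json',  # hypothetical local file
                     bigquery_table='my_dataset.my_table',        # hypothetical BigQuery table
                     project_id='my-project')                     # optional; defaults to _default_project()
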
@@ -151,6 +158,7 @@ def local_train(train_file_pattern,
                 transforms_file,
                 model_type,
                 max_steps,
+                top_n=None,
                 layer_sizes=None):
   """Train model locally.
   Args:
@@ -161,6 +169,9 @@ def local_train(train_file_pattern,
     transforms_file: File path to the transforms file.
     model_type: model type
     max_steps: Int. Number of training steps to perform.
+    top_n: Int. For classification problems, the output graph will contain the
+      labels and scores for the top n classes with a default of n=1. Use
+      None for regression problems.
     layer_sizes: List. Represents the layers in the connected DNN.
       If the model type is DNN, this must be set. Example [10, 3, 2], this
       will create three DNN layers where the first layer will have 10 nodes,
@@ -180,6 +191,8 @@ def local_train(train_file_pattern,
           '--max_steps=%s' % str(max_steps)]
   if layer_sizes:
     args.extend(['--layer_sizes'] + [str(x) for x in layer_sizes])
+  if top_n:
+    args.append('--top_n=%s' % str(top_n))

   print('Starting local training.')
   trainer.task.main(args)
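
A sketch of a classification training run using the new top_n flag; the file patterns and values are hypothetical, and eval_file_pattern and output_dir are assumed parameter names from the part of the signature not shown in this diff:

    local_train(train_file_pattern='./preprocess/features_train*',  # hypothetical pattern
                eval_file_pattern='./preprocess/features_eval*',    # assumed parameter name
                output_dir='./training',                            # assumed parameter name
                transforms_file='./transforms.json',                # hypothetical file
                model_type='dnn_classification',                    # hypothetical model type
                max_steps=2000,
                top_n=3,                 # keep labels and scores for the 3 best classes
                layer_sizes=[10, 3, 2])
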
@@ -192,6 +205,7 @@ def cloud_train(train_file_pattern,
                 transforms_file,
                 model_type,
                 max_steps,
+                top_n=None,
                 layer_sizes=None,
                 staging_bucket=None,
                 project_id=None,
@@ -208,6 +222,9 @@ def cloud_train(train_file_pattern,
     transforms_file: File path to the transforms file.
     model_type: model type
     max_steps: Int. Number of training steps to perform.
+    top_n: Int. For classification problems, the output graph will contain the
+      labels and scores for the top n classes with a default of n=1.
+      Use None for regression problems.
     layer_sizes: List. Represents the layers in the connected DNN.
       If the model type is DNN, this must be set. Example [10, 3, 2], this
       will create three DNN layers where the first layer will have 10 nodes,
@@ -238,6 +255,8 @@ def cloud_train(train_file_pattern,
           '--max_steps=%s' % str(max_steps)]
   if layer_sizes:
     args.extend(['--layer_sizes'] + [str(x) for x in layer_sizes])
+  if top_n:
+    args.append('--top_n=%s' % str(top_n))

   # TODO(brandondutra): move these package uris locally, ask for a staging
   # and copy them there. This package should work without cloudml having to
@@ -277,30 +296,89 @@ def cloud_train(train_file_pattern,
   print(job_request)


-def local_predict():
+def local_predict(model_dir, data):
   """Runs local prediction.

-  Runs local prediction in memory and prints the results to the screen. For
+  Runs local prediction and returns the result in a Pandas DataFrame. For
   running prediction on a large dataset or saving the results, run
   local_batch_prediction or batch_prediction.

   Args:
-
+    model_dir: local path to the trained model. Usually, this is
+      training_output_dir/model.
+    data: List of csv strings that match the model schema. Or a pandas DataFrame
+      where the columns match the model schema. The first column,
+      the target column, could be missing.
   """
   # Save the instances to a file, call local batch prediction, and print it back
+  tmp_dir = tempfile.mkdtemp()
+  _, input_file_path = tempfile.mkstemp(dir=tmp_dir, suffix='.csv',
+                                        prefix='input')

-
-
-
-def cloud_predict():
+  try:
+    if isinstance(data, pd.DataFrame):
+      data.to_csv(input_file_path, header=False, index=False)
+    else:
+      with open(input_file_path, 'w') as f:
+        for line in data:
+          f.write(line + '\n')
+
+    cmd = ['predict.py',
+           '--predict_data=%s' % input_file_path,
+           '--trained_model_dir=%s' % model_dir,
+           '--output_dir=%s' % tmp_dir,
+           '--output_format=csv',
+           '--batch_size=100',
+           '--no-shard_files']
+
+    print('Starting local prediction.')
+    predict.predict.main(cmd)
+    print('Local prediction done.')
+
+    # Read the header file.
+    with open(os.path.join(tmp_dir, 'csv_header.txt'), 'r') as f:
+      header = f.readline()
+
+    # Print any errors to the screen.
+    errors_file = glob.glob(os.path.join(tmp_dir, 'errors*'))
+    if errors_file and os.path.getsize(errors_file[0]) > 0:
+      print('Warning: there are errors. See below:')
+      with open(errors_file[0], 'r') as f:
+        text = f.read()
+        print(text)
+
+    # Read the predictions data.
+    prediction_file = glob.glob(os.path.join(tmp_dir, 'predictions*'))
+    if not prediction_file:
+      raise ValueError('Prediction results not found')
+    predictions = pd.read_csv(prediction_file[0],
+                              header=None,
+                              names=header.split(','))
+    return predictions
+  finally:
+    shutil.rmtree(tmp_dir)
+
+
+def cloud_predict(model_name, model_version, data, is_target_missing=False):
   """Use Online prediction.

   Runs online prediction in the cloud and prints the results to the screen. For
   running prediction on a large dataset or saving the results, run
   local_batch_prediction or batch_prediction.

   Args:
-
+    model_name: deployed model name
+    model_version: deployed model version
+    data: List of csv strings that match the model schema. Or a pandas DataFrame
+      where the columns match the model schema. The first column,
+      the target column, is assumed to exist in the data.
+    is_target_missing: If true, prepends a ',' in each csv string or adds an
+      empty DataFrame column. If the csv data has a leading ',' keep this flag
+      False. Example:
+      1) If data = ['target,input1,input2'], then set is_target_missing=False.
+      2) If data = [',input1,input2'], then set is_target_missing=False.
+      3) If data = ['input1,input2'], then set is_target_missing=True.

   Before using this, the model must be created. This can be done by running
   two gcloud commands:
@@ -313,12 +391,38 @@ def cloud_predict():
         --project=PROJECT
   Note that the model must be on GCS.
   """
-  pass
+  import datalab.mlalpha as mlalpha
+
+
+  if isinstance(data, pd.DataFrame):
+    # write the df to csv.
+    string_buffer = StringIO.StringIO()
+    data.to_csv(string_buffer, header=None, index=False)
+    csv_lines = string_buffer.getvalue().split('\n')
+
+    if is_target_missing:
+      input_data = [',' + csv for csv in csv_lines]
+    else:
+      input_data = csv_lines
+  else:
+    if is_target_missing:
+      input_data = [',' + csv for csv in data]
+    else:
+      input_data = data
+
+  cloud_predictor = mlalpha.CloudPredictor(model_name, model_version)
+  predictions = cloud_predictor.predict(input_data)

+  # Convert predictions into a dataframe
+  df = pd.DataFrame(columns=sorted(predictions[0].keys()))
+  for i in range(len(predictions)):
+    for k, v in predictions[i].iteritems():
+      df.loc[i, k] = v
+  return df


 def local_batch_predict(model_dir, prediction_input_file, output_dir,
-                        batch_size=1000, shard_files=True):
+                        batch_size=1000, shard_files=True, output_format='csv'):
   """Local batch prediction.

   Args:
@@ -329,12 +433,13 @@ def local_batch_predict(model_dir, prediction_input_file, output_dir,
     batch_size: Int. How many instances to run in memory at once. Larger values
       mean better performance but more memory consumed.
     shard_files: If false, the output files are not sharded.
+    output_format: csv or json. JSON files are newline-delimited.
   """
   cmd = ['predict.py',
          '--predict_data=%s' % prediction_input_file,
          '--trained_model_dir=%s' % model_dir,
          '--output_dir=%s' % output_dir,
-         '--output_format=csv',
+         '--output_format=%s' % output_format,
          '--batch_size=%s' % str(batch_size)]

   if shard_files:
@@ -343,13 +448,13 @@ def local_batch_predict(model_dir, prediction_input_file, output_dir,
     cmd.append('--no-shard_files')

   print('Starting local batch prediction.')
-  predict.predict.main(args)
+  predict.predict.main(cmd)
   print('Local batch prediction done.')



 def cloud_batch_predict(model_dir, prediction_input_file, output_dir,
-                        batch_size=1000, shard_files=True):
+                        batch_size=1000, shard_files=True, output_format='csv'):
   """Cloud batch prediction. Submits a Dataflow job.

   Args:
@@ -360,14 +465,15 @@ def cloud_batch_predict(model_dir, prediction_input_file, output_dir,
     batch_size: Int. How many instances to run in memory at once. Larger values
       mean better performance but more memory consumed.
     shard_files: If false, the output files are not sharded.
+    output_format: csv or json. JSON files are newline-delimited.
   """
   cmd = ['predict.py',
          '--cloud',
          '--project_id=%s' % _default_project(),
          '--predict_data=%s' % prediction_input_file,
          '--trained_model_dir=%s' % model_dir,
          '--output_dir=%s' % output_dir,
-         '--output_format=csv',
+         '--output_format=%s' % output_format,
          '--batch_size=%s' % str(batch_size)]

   if shard_files:
@@ -376,5 +482,5 @@ def cloud_batch_predict(model_dir, prediction_input_file, output_dir,
     cmd.append('--no-shard_files')

   print('Starting cloud batch prediction.')
-  predict.predict.main(args)
+  predict.predict.main(cmd)
   print('See above link for job status.')
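
A sketch of the two new prediction helpers; the model directory, deployed model name and version, and csv rows are hypothetical placeholders, and cloud_predict assumes the model was already created with the gcloud commands described in its docstring:

    # local_predict returns a pandas DataFrame of predictions.
    df = local_predict(model_dir='./training/model',   # hypothetical trained model dir
                       data=['7,some text,9.0',        # csv rows matching the model schema
                             '8,more text,3.2'])
    print(df.head())

    # These rows omit the leading target column, so is_target_missing=True.
    df = cloud_predict(model_name='my_model',          # hypothetical deployed model
                       model_version='v1',             # hypothetical version
                       data=['some text,9.0'],
                       is_target_missing=True)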
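
And a sketch of the new output_format argument on batch prediction; the paths are hypothetical, and 'json' writes newline-delimited JSON instead of csv:

    local_batch_predict(model_dir='./training/model',              # hypothetical
                        prediction_input_file='./features_eval*',  # hypothetical file pattern
                        output_dir='./batch_predictions',          # hypothetical
                        batch_size=100,
                        shard_files=False,
                        output_format='json')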