@@ -277,50 +277,30 @@ def cloud_train(train_file_pattern,
  print(job_request)


-def local_predict(model_dir, prediction_input_file):
+def local_predict():
"""Runs local prediction.
282
282
283
283
Runs local prediction in memory and prints the results to the screen. For
284
284
running prediction on a large dataset or saving the results, run
285
285
local_batch_prediction or batch_prediction.
286
286
287
287
Args:
-    model_dir: Path to folder that contains the model. This is usully OUT/model
-      where OUT is the value of output_dir when local_training was ran.
-    prediction_input_file: csv file that has the same schem as the input
-      files used during local_preprocess, except that the target column is
-      removed.
+
  """
-  pass
-  #TODO(brandondutra): remove this hack once cloudml 1.8 is released.
-  # Check that the model folder has a metadata.yaml file. If not, copy it.
-  # if not os.path.isfile(os.path.join(model_dir, 'metadata.yaml')):
-  #   shutil.copy2(os.path.join(model_dir, 'metadata.json'),
-  #                os.path.join(model_dir, 'metadata.yaml'))
+  # Save the instances to a file, call local batch prediction, and print it back
+
-  # cmd = ['gcloud beta ml local predict',
-  #        '--model-dir=%s' % model_dir,
-  #        '--text-instances=%s' % prediction_input_file]
-  # print('Local prediction, running command: %s' % ' '.join(cmd))
-  # _run_cmd(' '.join(cmd))
-  # print('Local prediction done.')


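The new body is still a stub; the added comment only sketches the plan. A minimal illustration of that plan, reusing local_batch_predict from later in this file (the `instances` argument, the tempfile/glob usage, and the placeholder model path are assumptions, not part of this commit):

    import glob
    import os
    import tempfile

    def local_predict(instances):
      # Write the in-memory instances (csv-formatted lines) to a temp file.
      tmp_dir = tempfile.mkdtemp()
      input_file = os.path.join(tmp_dir, 'instances.csv')
      with open(input_file, 'w') as f:
        f.write('\n'.join(instances))
      # Reuse the batch path on the small file.
      out_dir = os.path.join(tmp_dir, 'output')
      local_batch_predict(model_dir='training_output_dir/model',  # placeholder
                          prediction_input_file=input_file,
                          output_dir=out_dir,
                          batch_size=100, shard_files=False)
      # Print whatever results the batch predictor wrote back to the screen.
      for result_file in glob.glob(os.path.join(out_dir, '*')):
        with open(result_file) as f:
          print(f.read())
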
-def cloud_predict(model_name, prediction_input_file, version_name=None):
+def cloud_predict():
"""Use Online prediction.
311
297
312
298
Runs online prediction in the cloud and prints the results to the screen. For
313
299
running prediction on a large dataset or saving the results, run
314
300
local_batch_prediction or batch_prediction.
315
301
316
302
Args:
-    model_dir: Path to folder that contains the model. This is usully OUT/model
-      where OUT is the value of output_dir when local_training was ran.
-    prediction_input_file: csv file that has the same schem as the input
-      files used during local_preprocess, except that the target column is
-      removed.
-    vsersion_name: Optional version of the model to use. If None, the default
-      version is used.
+
  Before using this, the model must be created. This can be done by running
  two gcloud commands:
@@ -334,91 +314,67 @@ def cloud_predict(model_name, prediction_input_file, version_name=None):
  Note that the model must be on GCS.
  """
  pass
-  # cmd = ['gcloud beta ml predict',
-  #        '--model=%s' % model_name,
-  #        '--text-instances=%s' % prediction_input_file]
-  # if version_name:
-  #   cmd += ['--version=%s' % version_name]

-  # print('CloudML online prediction, running command: %s' % ' '.join(cmd))
-  # _run_cmd(' '.join(cmd))
-  # print('CloudML online prediction done.')


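The two model-creation commands the docstring refers to fall outside this hunk's context. Given the `gcloud beta ml` surface used elsewhere in this file, they would have been roughly as follows; the model/version names and exact flags here are assumptions, not text from this commit:

    gcloud beta ml models create MODEL_NAME
    gcloud beta ml versions create VERSION_NAME \
        --model MODEL_NAME \
        --origin gs://BUCKET/path/to/model
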
-def local_batch_predict(model_dir, prediction_input_file, output_dir):
+def local_batch_predict(model_dir, prediction_input_file, output_dir,
+                        batch_size=1000, shard_files=True):
"""Local batch prediction.
350
323
351
324
Args:
352
-    model_dir: local path to trained model.
-    prediction_input_file: File path to input files. May contain a file pattern.
-      Only csv files are supported, and the scema must match what was used
-      in preprocessing except that the target column is removed.
-    output_dir: folder to save results to.
+    model_dir: local file path to trained model. Usually, this is
+      training_output_dir/model.
+    prediction_input_file: csv file pattern to a local file.
+    output_dir: local output location to save the results.
+    batch_size: Int. How many instances to run in memory at once. Larger values
+      mean better performance but more memory consumed.
+    shard_files: If false, the output files are not sharded.
  """
-  pass
-  #TODO(brandondutra): remove this hack once cloudml 1.8 is released.
-  # Check that the model folder has a metadata.yaml file. If not, copy it.
-  # if not os.path.isfile(os.path.join(model_dir, 'metadata.yaml')):
-  #   shutil.copy2(os.path.join(model_dir, 'metadata.json'),
-  #                os.path.join(model_dir, 'metadata.yaml'))
+  cmd = ['predict.py',
+         '--predict_data=%s' % prediction_input_file,
+         '--trained_model_dir=%s' % model_dir,
+         '--output_dir=%s' % output_dir,
+         '--output_format=csv',
+         '--batch_size=%s' % str(batch_size)]
+
+  if shard_files:
+    cmd.append('--shard_files')
+  else:
+    cmd.append('--no-shard_files')
-  # cmd = ['python -m google.cloud.ml.dataflow.batch_prediction_main',
-  #        '--input_file_format=text',
-  #        '--input_file_patterns=%s' % prediction_input_file,
-  #        '--output_location=%s' % output_dir,
-  #        '--model_dir=%s' % model_dir]
+  print('Starting local batch prediction.')
+  predict.predict.main(cmd)
+  print('Local batch prediction done.')
-  # print('Local batch prediction, running command: %s' % ' '.join(cmd))
-  # _run_cmd(' '.join(cmd))
-  # print('Local batch prediction done.')


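For reference, a call to the new local entry point might look like this; the paths are placeholders, not values from this commit:

    local_batch_predict(model_dir='./training_output_dir/model',
                        prediction_input_file='./data/predict-*.csv',
                        output_dir='./predictions',
                        batch_size=1000,
                        shard_files=False)
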
-def cloud_batch_predict(model_name, prediction_input_file, output_dir, region,
-                        job_name=None, version_name=None):
-  """Cloud batch prediction.
+def cloud_batch_predict(model_dir, prediction_input_file, output_dir,
+                        batch_size=1000, shard_files=True):
+  """Cloud batch prediction. Submits a Dataflow job.
  Args:
-    model_name: name of the model. The model must already exist.
-    prediction_input_file: File path to input files. May contain a file pattern.
-      Only csv files are supported, and the scema must match what was used
-      in preprocessing except that the target column is removed. Files must
-      be on GCS
-    output_dir: GCS folder to safe results to.
-    region: GCP compute region to run the batch job. Try using your default
-      region first, as this cloud batch prediction is not avaliable in all
-      regions.
-    job_name: job name used for the cloud job.
-    version_name: model version to use. If node, the default version of the
-      model is used.
+    model_dir: GCS file path to trained model. Usually, this is
+      training_output_dir/model.
+    prediction_input_file: csv file pattern to a GCS file.
+    output_dir: Location to save the results on GCS.
+    batch_size: Int. How many instances to run in memory at once. Larger values
+      mean better performance but more memory consumed.
+    shard_files: If false, the output files are not sharded.
  """
-  pass
-  # job_name = job_name or ('structured_data_batch_predict_' +
-  #                         datetime.datetime.now().strftime('%Y%m%d%H%M%S'))
-
-  # if (not prediction_input_file.startswith('gs://') or
-  #     not output_dir.startswith('gs://')):
-  #   print('ERROR: prediction_input_file and output_dir must point to a '
-  #         'location on GCS.')
-  #   return
-
-  # cmd = ['gcloud beta ml jobs submit prediction %s' % job_name,
-  #        '--model=%s' % model_name,
-  #        '--region=%s' % region,
-  #        '--data-format=TEXT',
-  #        '--input-paths=%s' % prediction_input_file,
-  #        '--output-path=%s' % output_dir]
-  # if version_name:
-  #   cmd += ['--version=%s' % version_name]
-
-  # print('CloudML batch prediction, running command: %s' % ' '.join(cmd))
-  # _run_cmd(' '.join(cmd))
-  # print('CloudML batch prediction job submitted.')
-
-  # if _is_in_IPython():
-  #   import IPython
-
-  #   dataflow_url = ('https://console.developers.google.com/ml/jobs?project=%s'
-  #                   % _default_project())
-  #   html = ('<p>Click <a href="%s" target="_blank">here</a> to track '
-  #           'the prediction job %s.</p><br/>' % (dataflow_url, job_name))
-  #   IPython.display.display_html(html, raw=True)
+  cmd = ['predict.py',
+         '--cloud',
+         '--project_id=%s' % _default_project(),
+         '--predict_data=%s' % prediction_input_file,
+         '--trained_model_dir=%s' % model_dir,
+         '--output_dir=%s' % output_dir,
+         '--output_format=csv',
+         '--batch_size=%s' % str(batch_size)]
+
+  if shard_files:
+    cmd.append('--shard_files')
+  else:
+    cmd.append('--no-shard_files')
+
+  print('Starting cloud batch prediction.')
+  predict.predict.main(cmd)
+  print('See above link for job status.')
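And the cloud variant, which submits a Dataflow job and therefore expects GCS paths throughout; the bucket and paths below are placeholders:

    cloud_batch_predict(model_dir='gs://my-bucket/training_output_dir/model',
                        prediction_input_file='gs://my-bucket/data/predict-*.csv',
                        output_dir='gs://my-bucket/predictions',
                        batch_size=1000,
                        shard_files=True)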