Commit 2cf04df
prediction update (#183)
* added the ',' graph hack
* sw
* batch prediction done
* sw
* review comments
* updated the prediction graph keys, and made the csvcoder not need any other file
* sw
* sw
* added newline
* review comments
* review comments
* trying to fix the Contributor License Agreement error
1 parent 9bff1fb commit 2cf04df

File tree

7 files changed: +224 −118 lines changed

solutionbox/structured_data/datalab_solutions/structured_data/_package.py

Lines changed: 124 additions & 18 deletions
@@ -37,12 +37,16 @@
 import tempfile
 import urllib
 import json
+import glob
+import StringIO
 
+import pandas as pd
 import tensorflow as tf
 import yaml
 
 from . import preprocess
 from . import trainer
+from . import predict
 
 _TF_GS_URL = 'gs://cloud-datalab/deploy/tf/tensorflow-0.12.0rc0-cp27-none-linux_x86_64.whl'
 
@@ -112,7 +116,8 @@ def local_preprocess(output_dir, input_feature_file, input_file_pattern, schema_
 def cloud_preprocess(output_dir, input_feature_file, input_file_pattern=None, schema_file=None, bigquery_table=None, project_id=None):
   """Preprocess data in the cloud with BigQuery.
 
-  Produce analysis used by training.
+  Produce analysis used by training. This can take a while, even for small
+  datasets. For small datasets, it may be faster to use local_preprocess.
 
   Args:
     output_dir: The output directory to use.
@@ -133,13 +138,15 @@ def cloud_preprocess(output_dir, input_feature_file, input_file_pattern=None, sc
     args.append('--input_file_pattern=%s' % input_file_pattern)
   if schema_file:
     args.append('--schema_file=%s' % schema_file)
+  if not project_id:
+    project_id = _default_project()
   if bigquery_table:
-    if not project_id:
-      project_id = _default_project()
     full_name = project_id + ':' + bigquery_table
     args.append('--bigquery_table=%s' % full_name)
 
   print('Starting cloud preprocessing.')
+  print('Track BigQuery status at')
+  print('https://bigquery.cloud.google.com/queries/%s' % project_id)
   preprocess.cloud_preprocess.main(args)
   print('Cloud preprocessing done.')
 
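For context, a hypothetical call to the updated function might look like the sketch below. The import path assumes the function is re-exported at package level, and the output location, feature file, and BigQuery table are all placeholders:

from datalab_solutions.structured_data import cloud_preprocess

# project_id can now be omitted; it falls back to _default_project(),
# and the printed link can be used to track the BigQuery job.
cloud_preprocess(output_dir='gs://my-bucket/preprocess',        # placeholder
                 input_feature_file='gs://my-bucket/features',  # placeholder
                 bigquery_table='my_dataset.my_table')          # placeholder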
@@ -151,6 +158,7 @@ def local_train(train_file_pattern,
                 transforms_file,
                 model_type,
                 max_steps,
+                top_n=None,
                 layer_sizes=None):
   """Train model locally.
   Args:
@@ -161,6 +169,9 @@ def local_train(train_file_pattern,
     transforms_file: File path to the transforms file.
     model_type: model type
     max_steps: Int. Number of training steps to perform.
+    top_n: Int. For classification problems, the output graph will contain the
+      labels and scores for the top n classes with a default of n=1. Use
+      None for regression problems.
     layer_sizes: List. Represents the layers in the connected DNN.
       If the model type is DNN, this must be set. Example [10, 3, 2], this
       will create three DNN layers where the first layer will have 10 nodes,
@@ -180,6 +191,8 @@ def local_train(train_file_pattern,
           '--max_steps=%s' % str(max_steps)]
   if layer_sizes:
     args.extend(['--layer_sizes'] + [str(x) for x in layer_sizes])
+  if top_n:
+    args.append('--top_n=%s' % str(top_n))
 
   print('Starting local training.')
   trainer.task.main(args)
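
A hedged usage sketch of the new top_n flag. Only the arguments visible in this diff are shown (the real signature has more parameters), the file paths are placeholders, and the model_type value is an assumption:

from datalab_solutions.structured_data import local_train

# For a classification model, export the labels and scores of the
# top 3 classes; leave top_n=None for regression.
local_train(train_file_pattern='/tmp/preprocess/features_train*',  # placeholder
            transforms_file='/tmp/transforms.json',                # placeholder
            model_type='dnn_classification',                       # assumed value
            max_steps=2500,
            top_n=3,
            layer_sizes=[10, 3, 2])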
@@ -192,6 +205,7 @@ def cloud_train(train_file_pattern,
                 transforms_file,
                 model_type,
                 max_steps,
+                top_n=None,
                 layer_sizes=None,
                 staging_bucket=None,
                 project_id=None,
@@ -208,6 +222,9 @@ def cloud_train(train_file_pattern,
     transforms_file: File path to the transforms file.
     model_type: model type
     max_steps: Int. Number of training steps to perform.
+    top_n: Int. For classification problems, the output graph will contain the
+      labels and scores for the top n classes with a default of n=1.
+      Use None for regression problems.
     layer_sizes: List. Represents the layers in the connected DNN.
       If the model type is DNN, this must be set. Example [10, 3, 2], this
       will create three DNN layers where the first layer will have 10 nodes,
@@ -238,6 +255,8 @@ def cloud_train(train_file_pattern,
           '--max_steps=%s' % str(max_steps)]
   if layer_sizes:
     args.extend(['--layer_sizes'] + [str(x) for x in layer_sizes])
+  if top_n:
+    args.append('--top_n=%s' % str(top_n))
 
   # TODO(brandondutra): move these package uris locally, ask for a staging
   # and copy them there. This package should work without cloudml having to
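
cloud_train threads top_n through in the same way; a sketch with placeholder GCS paths and project (again, parameters not shown in this diff are elided and model_type is an assumed value):

from datalab_solutions.structured_data import cloud_train

cloud_train(train_file_pattern='gs://my-bucket/preprocess/features_train*',
            transforms_file='gs://my-bucket/transforms.json',  # placeholder
            model_type='dnn_classification',                   # assumed value
            max_steps=2500,
            top_n=3,                 # classification only; None for regression
            layer_sizes=[10, 3, 2],
            staging_bucket='gs://my-bucket',                   # placeholder
            project_id='my-project')                           # placeholder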
@@ -277,30 +296,89 @@ def cloud_train(train_file_pattern,
   print(job_request)
 
 
-def local_predict():
+def local_predict(model_dir, data):
   """Runs local prediction.
 
-  Runs local prediction in memory and prints the results to the screen. For
+  Runs local prediction and returns the result in a Pandas DataFrame. For
   running prediction on a large dataset or saving the results, run
   local_batch_prediction or batch_prediction.
 
   Args:
-
+    model_dir: local path to the trained model. Usually, this is
+      training_output_dir/model.
+    data: List of csv strings that match the model schema. Or a pandas DataFrame
+      where the columns match the model schema. The first column,
+      the target column, could be missing.
   """
   # Save the instances to a file, call local batch prediction, and print it back
+  tmp_dir = tempfile.mkdtemp()
+  _, input_file_path = tempfile.mkstemp(dir=tmp_dir, suffix='.csv',
+                                        prefix='input')
 
-
-
-
-def cloud_predict():
+  try:
+    if isinstance(data, pd.DataFrame):
+      data.to_csv(input_file_path, header=False, index=False)
+    else:
+      with open(input_file_path, 'w') as f:
+        for line in data:
+          f.write(line + '\n')
+
+
+    cmd = ['predict.py',
+           '--predict_data=%s' % input_file_path,
+           '--trained_model_dir=%s' % model_dir,
+           '--output_dir=%s' % tmp_dir,
+           '--output_format=csv',
+           '--batch_size=100',
+           '--no-shard_files']
+
+    print('Starting local prediction.')
+    predict.predict.main(cmd)
+    print('Local prediction done.')
+
+    # Read the header file.
+    with open(os.path.join(tmp_dir, 'csv_header.txt'), 'r') as f:
+      header = f.readline()
+
+    # Print any errors to the screen.
+    errors_file = glob.glob(os.path.join(tmp_dir, 'errors*'))
+    if errors_file and os.path.getsize(errors_file[0]) > 0:
+      print('Warning: there are errors. See below:')
+      with open(errors_file[0], 'r') as f:
+        text = f.read()
+        print(text)
+
+    # Read the predictions data.
+    prediction_file = glob.glob(os.path.join(tmp_dir, 'predictions*'))
+    if not prediction_file:
+      raise FileNotFoundError('Prediction results not found')
+    predictions = pd.read_csv(prediction_file[0],
+                              header=None,
+                              names=header.split(','))
+    return predictions
+  finally:
+    shutil.rmtree(tmp_dir)
+
+
+def cloud_predict(model_name, model_version, data, is_target_missing=False):
   """Use Online prediction.
 
   Runs online prediction in the cloud and prints the results to the screen. For
   running prediction on a large dataset or saving the results, run
   local_batch_prediction or batch_prediction.
 
   Args:
-
+    model_name: deployed model name
+    model_version: deployed model version
+    data: List of csv strings that match the model schema. Or a pandas DataFrame
+      where the columns match the model schema. The first column,
+      the target column, is assumed to exist in the data.
+    is_target_missing: If true, prepends a ',' in each csv string or adds an
+      empty DataFrame column. If the csv data has a leading ',' keep this flag
+      False. Example:
+      1) If data = ['target,input1,input2'], then set is_target_missing=False.
+      2) If data = [',input1,input2'], then set is_target_missing=False.
+      3) If data = ['input1,input2'], then set is_target_missing=True.
 
   Before using this, the model must be created. This can be done by running
   two gcloud commands:
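
Based on the new docstring, a hypothetical local_predict call; the model path and rows are placeholders, assuming a schema of a target column plus two inputs (a DataFrame with matching columns would work the same way):

from datalab_solutions.structured_data import local_predict

# The leading ',' leaves the optional target column empty.
df = local_predict(model_dir='/tmp/training/model',  # placeholder
                   data=[',2.5,red', ',1.0,blue'])   # placeholder rows
print(df)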
@@ -313,12 +391,38 @@ def cloud_predict():
       --project=PROJECT
   Note that the model must be on GCS.
   """
-  pass
+  import datalab.mlalpha as mlalpha
+
+
+  if isinstance(data, pd.DataFrame):
+    # write the df to csv.
+    string_buffer = StringIO.StringIO()
+    data.to_csv(string_buffer, header=None, index=False)
+    csv_lines = string_buffer.getvalue().split('\n')
+
+    if is_target_missing:
+      input_data = [',' + csv for csv in csv_lines]
+    else:
+      input_data = csv_lines
+  else:
+    if is_target_missing:
+      input_data = [',' + csv for csv in data]
+    else:
+      input_data = data
+
+  cloud_predictor = mlalpha.CloudPredictor(model_name, model_version)
+  predictions = cloud_predictor.predict(input_data)
 
+  # Convert predictions into a dataframe
+  df = pd.DataFrame(columns=sorted(predictions[0].keys()))
+  for i in range(len(predictions)):
+    for k, v in predictions[i].iteritems():
+      df.loc[i, k] = v
+  return df
 
 
 def local_batch_predict(model_dir, prediction_input_file, output_dir,
-                        batch_size=1000, shard_files=True):
+                        batch_size=1000, shard_files=True, output_format='csv'):
   """Local batch prediction.
 
   Args:
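
A sketch of cloud_predict using example 3 from its docstring: the rows carry no target column, so is_target_missing=True tells the function to prepend ',' to each line. Model name and version are placeholders, and the model must already be deployed:

from datalab_solutions.structured_data import cloud_predict

df = cloud_predict(model_name='my_model',         # placeholder
                   model_version='v1',            # placeholder
                   data=['2.5,red', '1.0,blue'],  # no target column
                   is_target_missing=True)
print(df)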
@@ -329,12 +433,13 @@ def local_batch_predict(model_dir, prediction_input_file, output_dir,
     batch_size: Int. How many instances to run in memory at once. Larger values
       mean better performance but more memory consumed.
     shard_files: If false, the output files are not sharded.
+    output_format: csv or json. JSON files are newline-delimited.
   """
   cmd = ['predict.py',
          '--predict_data=%s' % prediction_input_file,
          '--trained_model_dir=%s' % model_dir,
          '--output_dir=%s' % output_dir,
-         '--output_format=csv',
+         '--output_format=%s' % output_format,
          '--batch_size=%s' % str(batch_size)]
 
   if shard_files:
@@ -343,13 +448,13 @@ def local_batch_predict(model_dir, prediction_input_file, output_dir,
     cmd.append('--no-shard_files')
 
   print('Starting local batch prediction.')
-  predict.predict.main(args)
+  predict.predict.main(cmd)
   print('Local batch prediction done.')
 
 
 
 def cloud_batch_predict(model_dir, prediction_input_file, output_dir,
-                        batch_size=1000, shard_files=True):
+                        batch_size=1000, shard_files=True, output_format='csv'):
   """Cloud batch prediction. Submits a Dataflow job.
 
   Args:
360465
batch_size: Int. How many instances to run in memory at once. Larger values
361466
mean better performace but more memeory consumed.
362467
shard_files: If false, the output files are not shardded.
468+
output_format: csv or json. Json file are json-newlined.
363469
"""
364470
cmd = ['predict.py',
365471
'--cloud',
366472
'--project_id=%s' % _default_project(),
367473
'--predict_data=%s' % prediction_input_file,
368474
'--trained_model_dir=%s' % model_dir,
369475
'--output_dir=%s' % output_dir,
370-
'--output_format=csv',
476+
'--output_format=%s' % output_format,
371477
'--batch_size=%s' % str(batch_size)]
372478

373479
if shard_files:
@@ -376,5 +482,5 @@ def cloud_batch_predict(model_dir, prediction_input_file, output_dir,
     cmd.append('--no-shard_files')
 
   print('Starting cloud batch prediction.')
-  predict.predict.main(args)
+  predict.predict.main(cmd)
   print('See above link for job status.')
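
And the Dataflow-backed variant, again with placeholder GCS paths:

from datalab_solutions.structured_data import cloud_batch_predict

# Submits a Dataflow job; track progress at the printed link.
cloud_batch_predict(model_dir='gs://my-bucket/training/model',        # placeholder
                    prediction_input_file='gs://my-bucket/input.csv', # placeholder
                    output_dir='gs://my-bucket/predictions',          # placeholder
                    output_format='csv')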

solutionbox/structured_data/datalab_solutions/structured_data/predict/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import predict
+import predict
+