This repository was archived by the owner on Sep 3, 2022. It is now read-only.

Commit cef4eae

Inception Package Improvements (#186)
* Implement inception cloud batch prediction. Support explicit eval data in preprocessing.
* Follow up on CR comments. Also address changes from latest DataFlow.
1 parent 299163f commit cef4eae

File tree

7 files changed (+426, -187 lines)

datalab/mlalpha/_dataset.py

Lines changed: 52 additions & 16 deletions
@@ -37,15 +37,30 @@ def __init__(self, files, schema=None, schema_file=None):
       schema: A BigQuery schema object in the form of
           [{'name': 'col1', 'type': 'STRING'},
            {'name': 'col2', 'type': 'INTEGER'}]
+          or a single string in the form of 'col1:STRING,col2:INTEGER,col3:FLOAT'.
       schema_file: A JSON serialized schema file. If schema is None, it will try to load from
           schema_file if not None.
+    Raises:
+      ValueError if both schema and schema_file are None.
     """
-    self._schema = None
+    if schema is None and schema_file is None:
+      raise ValueError('schema and schema_file cannot both be None.')
+
     if schema is not None:
-      self._schema = schema
-    elif schema_file is not None:
+      if isinstance(schema, list):
+        self._schema = schema
+      else:
+        self._schema = []
+        for x in schema.split(','):
+          parts = x.split(':')
+          if len(parts) != 2:
+            raise ValueError('invalid schema string "%s"' % x)
+          self._schema.append({'name': parts[0].strip(), 'type': parts[1].strip()})
+    else:
       with ml.util._file.open_local_or_gcs(schema_file, 'r') as f:
         self._schema = json.load(f)
+    if isinstance(files, basestring):
+      files = [files]
     self._files = []
     for file in files:
       # glob_files() returns unicode strings which doesn't make DataFlow happy. So str().
@@ -97,28 +112,48 @@ def sample(self, n):
       skip = [x for x in skip_all if x < row_count]
       skip_all = [x - row_count for x in skip_all if x >= row_count]
       with ml.util._file.open_local_or_gcs(file, 'r') as f:
-        dfs.append(pd.read_csv(file, skiprows=skip, names=names, dtype=dtype, header=None))
+        dfs.append(pd.read_csv(f, skiprows=skip, names=names, dtype=dtype, header=None))
     return pd.concat(dfs, axis=0, ignore_index=True)
 
 
 class BigQueryDataSet(object):
   """DataSet based on BigQuery table or query."""
 
-  def __init__(self, sql):
+  def __init__(self, sql=None, table=None):
     """
     Args:
-      sql: Can be one of:
-          A table name.
-          A SQL query string.
-          A SQL Query module defined with '%%sql --name [module_name]'
+      sql: A SQL query string, or a SQL Query module defined with '%%sql --name [module_name]'
+      table: A table name in the form of "dataset:table".
+    Raises:
+      ValueError if both sql and table are set, or both are None.
     """
-    query, _ = datalab.data.SqlModule.get_sql_statement_with_environment(sql, {})
-    self._sql = query.sql
+    if (sql is None and table is None) or (sql is not None and table is not None):
+      raise ValueError('One and only one of sql and table should be set.')
+
+    self._query = None
+    self._table = None
+    if sql is not None:
+      query, _ = datalab.data.SqlModule.get_sql_statement_with_environment(sql, {})
+      self._query = query.sql
+    if table is not None:
+      self._table = table
+    self._schema = None
 
   @property
-  def sql(self):
-    return self._sql
-
+  def query(self):
+    return self._query
+
+  @property
+  def table(self):
+    return self._table
+
+  @property
+  def schema(self):
+    if self._schema is None:
+      source = self._query or self._table
+      self._schema = bq.Query('SELECT * FROM (%s) LIMIT 1' % source).results().schema
+    return self._schema
+
   def sample(self, n):
     """Samples data into a Pandas DataFrame. Note that it calls BigQuery so it will
     incur cost.
@@ -129,10 +164,11 @@ def sample(self, n):
     Raises:
       Exception if n is larger than number of rows.
     """
-    total = bq.Query('select count(*) from (%s)' % self._sql).results()[0].values()[0]
+    source = self._query or self._table
+    total = bq.Query('select count(*) from (%s)' % source).results()[0].values()[0]
    if n > total:
      raise ValueError('sample larger than population')
    sampling = bq.Sampling.random(n*100.0/float(total))
-    sample = bq.Query(self._sql).sample(sampling=sampling)
+    sample = bq.Query(source).sample(sampling=sampling)
    df = sample.to_dataframe()
    return df
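
The net effect of the _dataset.py changes: CsvDataSet now accepts a compact schema string and a single file path, and BigQueryDataSet separates sql from table and exposes query, table and schema properties. A minimal usage sketch (the GCS path and table name below are hypothetical, and the calls assume datalab.mlalpha is importable in a Datalab notebook):

import datalab.mlalpha as mlalpha

# Schema given as a 'name:TYPE,...' string; a single file path is also accepted.
csv_data = mlalpha.CsvDataSet('gs://my-bucket/images/train.csv',
                              schema='image_url:STRING,label:STRING')
print csv_data.sample(5)          # reads a few rows into a Pandas DataFrame

# Either sql or table must be set, never both.
bq_data = mlalpha.BigQueryDataSet(table='mydataset:labeled_images')
print bq_data.schema              # fetched lazily with a 'SELECT * ... LIMIT 1' query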

solutionbox/inception/datalab_solutions/inception/_cloud.py

Lines changed: 30 additions & 8 deletions
@@ -26,6 +26,7 @@
 
 
 from . import _model
+from . import _predictor
 from . import _preprocess
 from . import _trainer
 from . import _util
@@ -55,7 +56,7 @@ def _repackage_to_staging(self, output_path):
     mlalpha.package_and_copy(package_root, _SETUP_PY, staging_package_url)
     return staging_package_url
 
-  def preprocess(self, dataset, output_dir, pipeline_option=None):
+  def preprocess(self, train_dataset, eval_dataset, output_dir, pipeline_option):
     """Cloud preprocessing with Cloud DataFlow."""
 
     import datalab.mlalpha as mlalpha
@@ -76,13 +77,8 @@ def preprocess(self, dataset, output_dir, pipeline_option=None):
 
     opts = beam.pipeline.PipelineOptions(flags=[], **options)
     p = beam.Pipeline('DataflowRunner', options=opts)
-    if type(dataset) is mlalpha.CsvDataSet:
-      _preprocess.configure_pipeline_csv(p, self._checkpoint, dataset.files, output_dir, job_name)
-    elif type(dataset) is mlalpha.BigQueryDataSet:
-      _preprocess.configure_pipeline_bigquery(p, self._checkpoint, dataset.sql,
-                                              output_dir, job_name)
-    else:
-      raise ValueError('preprocess takes CsvDataSet or BigQueryDataset only.')
+    _preprocess.configure_pipeline(p, train_dataset, eval_dataset, self._checkpoint,
+                                   output_dir, job_name)
     p.run()
     return job_name
 
@@ -136,3 +132,29 @@ def predict(self, model_id, image_files):
     labels_and_scores = [(x['prediction'], x['scores'][labels.index(x['prediction'])])
                          for x in predictions]
     return labels_and_scores
+
+  def batch_predict(self, dataset, model_dir, gcs_staging_location, output_csv,
+                    output_bq_table, pipeline_option):
+    """Cloud batch prediction with a model specified by a GCS directory."""
+
+    import datalab.mlalpha as mlalpha
+
+    job_name = 'batch-predict-inception-' + datetime.datetime.now().strftime('%y%m%d-%H%M%S')
+    staging_package_url = self._repackage_to_staging(gcs_staging_location)
+    options = {
+        'staging_location': os.path.join(gcs_staging_location, 'tmp', 'staging'),
+        'temp_location': os.path.join(gcs_staging_location, 'tmp'),
+        'job_name': job_name,
+        'project': _util.default_project(),
+        'extra_packages': [ml.sdk_location, staging_package_url, _TF_GS_URL],
+        'teardown_policy': 'TEARDOWN_ALWAYS',
+        'no_save_main_session': True
+    }
+    if pipeline_option is not None:
+      options.update(pipeline_option)
+
+    opts = beam.pipeline.PipelineOptions(flags=[], **options)
+    p = beam.Pipeline('DataflowRunner', options=opts)
+    _predictor.configure_pipeline(p, dataset, model_dir, output_csv, output_bq_table)
+    p.run()
+    return job_name
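
As a rough sketch of how the new Cloud.batch_predict entry point is meant to be driven (bucket paths, the model directory, the num_workers option, and the import path below are assumptions; in practice the call is wrapped by cloud_batch_predict in _package.py further down):

import datalab.mlalpha as mlalpha
from datalab_solutions.inception import _cloud   # assumed import path for this package

predict_set = mlalpha.CsvDataSet('gs://my-bucket/to_predict.csv',
                                 schema='image_url:STRING')

# staging_location and temp_location are derived from gcs_staging_location;
# entries in pipeline_option override the defaults assembled above.
job_name = _cloud.Cloud().batch_predict(
    dataset=predict_set,
    model_dir='gs://my-bucket/inception/model',
    gcs_staging_location='gs://my-bucket/staging',
    output_csv='gs://my-bucket/predict_results.csv',
    output_bq_table=None,
    pipeline_option={'num_workers': 2})
print 'Submitted DataFlow job: %s' % job_name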

solutionbox/inception/datalab_solutions/inception/_local.py

Lines changed: 14 additions & 10 deletions
@@ -42,7 +42,7 @@ def __init__(self, checkpoint=None):
     if self._checkpoint is None:
       self._checkpoint = _util._DEFAULT_CHECKPOINT_GSURL
 
-  def preprocess(self, dataset, output_dir):
+  def preprocess(self, train_dataset, eval_dataset, output_dir):
     """Local preprocessing with local DataFlow."""
 
     import datalab.mlalpha as mlalpha
@@ -53,12 +53,8 @@ def preprocess(self, dataset, output_dir):
     }
     opts = beam.pipeline.PipelineOptions(flags=[], **options)
     p = beam.Pipeline('DirectRunner', options=opts)
-    if type(dataset) is mlalpha.CsvDataSet:
-      _preprocess.configure_pipeline_csv(p, self._checkpoint, dataset.files, output_dir, job_id)
-    elif type(dataset) is mlalpha.BigQueryDataSet:
-      _preprocess.configure_pipeline_bigquery(p, self._checkpoint, dataset.sql, output_dir, job_id)
-    else:
-      raise ValueError('preprocess takes CsvDataSet or BigQueryDataset only.')
+    _preprocess.configure_pipeline(p, train_dataset, eval_dataset,
+                                   self._checkpoint, output_dir, job_id)
     p.run().wait_until_finish()
 
   def train(self, input_dir, batch_size, max_steps, output_dir):
@@ -77,7 +73,15 @@ def predict(self, model_dir, image_files):
     return _predictor.predict(model_dir, image_files)
 
 
-  def batch_predict(self, model_dir, input_csv, output_file, output_bq_table):
+  def batch_predict(self, dataset, model_dir, output_csv, output_bq_table):
     """Local batch prediction."""
-
-    return _predictor.batch_predict(model_dir, input_csv, output_file, output_bq_table)
+    import datalab.mlalpha as mlalpha
+    job_id = 'inception_batch_predict_' + datetime.datetime.now().strftime('%y%m%d_%H%M%S')
+    # Project is needed for bigquery data source, even in local run.
+    options = {
+        'project': _util.default_project(),
+    }
+    opts = beam.pipeline.PipelineOptions(flags=[], **options)
+    p = beam.Pipeline('DirectRunner', options=opts)
+    _predictor.configure_pipeline(p, dataset, model_dir, output_csv, output_bq_table)
+    p.run().wait_until_finish()
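
A short sketch of the reworked Local runner (file paths are hypothetical; a default project still needs to be configured because BigQuery sources query the service even under DirectRunner, as the comment in the diff notes):

import datalab.mlalpha as mlalpha
from datalab_solutions.inception import _local   # assumed import path for this package

train_set = mlalpha.CsvDataSet('./data/train.csv',
                               schema='image_url:STRING,label:STRING')
eval_set = mlalpha.CsvDataSet('./data/eval.csv',
                              schema='image_url:STRING,label:STRING')

runner = _local.Local()
# With an explicit eval_dataset there is no random 7:3 split of the training data.
runner.preprocess(train_set, eval_set, './preprocessed')

# ... train a model from the preprocessed output, then batch predict locally.
predict_set = mlalpha.CsvDataSet('./data/to_predict.csv', schema='image_url:STRING')
runner.batch_predict(predict_set, './trained_model', './results.csv', None)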

solutionbox/inception/datalab_solutions/inception/_package.py

Lines changed: 69 additions & 20 deletions
@@ -36,39 +36,49 @@
 from . import _util
 
 
-def local_preprocess(dataset, output_dir, checkpoint=None):
+def local_preprocess(train_dataset, output_dir, checkpoint=None, eval_dataset=None):
   """Preprocess data locally. Produce output that can be used by training efficiently.
   Args:
-    dataset: data source to preprocess. Can be either datalab.mlalpha.CsvDataset, or
-        datalab.mlalpha.BigQueryDataSet.
+    train_dataset: training data source to preprocess. Can be CsvDataset or BigQueryDataSet.
+        If eval_dataset is None, the pipeline will randomly split train_dataset into
+        train/eval set with 7:3 ratio.
     output_dir: The output directory to use. Preprocessing will create a sub directory under
         it for each run, and also update "latest" file which points to the latest preprocessed
         directory. Users are responsible for cleanup. Can be local or GCS path.
     checkpoint: the Inception checkpoint to use.
+    eval_dataset: evaluation data source to preprocess. Can be CsvDataset or BigQueryDataSet.
+        If specified, it will be used for evaluation during training, and train_dataset will be
+        completely used for training.
   """
 
   print 'Local preprocessing...'
   # TODO: Move this to a new process to avoid pickling issues
   # TODO: Expose train/eval split ratio
-  _local.Local(checkpoint).preprocess(dataset, output_dir)
+  _local.Local(checkpoint).preprocess(train_dataset, eval_dataset, output_dir)
   print 'Done'
 
 
-def cloud_preprocess(dataset, output_dir, checkpoint=None, pipeline_option=None):
+def cloud_preprocess(train_dataset, output_dir, checkpoint=None, pipeline_option=None,
+                     eval_dataset=None):
   """Preprocess data in Cloud with DataFlow.
   Produce output that can be used by training efficiently.
   Args:
-    dataset: data source to preprocess. Can be either datalab.mlalpha.CsvDataset, or
-        datalab.mlalpha.BigQueryDataSet. For CsvDataSet, all files need to be in GCS.
+    train_dataset: training data source to preprocess. Can be CsvDataset or BigQueryDataSet.
+        For CsvDataSet, all files must be in GCS.
+        If eval_dataset is None, the pipeline will randomly split train_dataset into
+        train/eval set with 7:3 ratio.
     output_dir: The output directory to use. Preprocessing will create a sub directory under
         it for each run, and also update "latest" file which points to the latest preprocessed
        directory. Users are responsible for cleanup. GCS path only.
     checkpoint: the Inception checkpoint to use.
+    pipeline_option: DataFlow pipeline options in a dictionary.
+    eval_dataset: evaluation data source to preprocess. Can be CsvDataset or BigQueryDataSet.
+        If specified, it will be used for evaluation during training, and train_dataset will be
+        completely used for training.
   """
 
-  # TODO: Move this to a new process to avoid pickling issues
-  # TODO: Expose train/eval split ratio
-  job_name = _cloud.Cloud(checkpoint=checkpoint).preprocess(dataset, output_dir, pipeline_option)
+  job_name = _cloud.Cloud(checkpoint=checkpoint).preprocess(train_dataset, eval_dataset,
+                                                            output_dir, pipeline_option)
   if (_util.is_in_IPython()):
     import IPython
 
@@ -172,19 +182,58 @@ def cloud_predict(model_id, image_files, show_image=True):
   _display_predict_results(results, show_image)
 
 
-def local_batch_predict(model_dir, input_csv, output_file, output_bq_table=None):
-  """Batch predict using an offline model.
+def local_batch_predict(dataset, model_dir, output_csv=None, output_bq_table=None):
+  """Batch predict running locally.
   Args:
+    dataset: CsvDataSet or BigQueryDataSet for batch prediction input. Can contain either
+        one column 'image_url', or two columns with another being 'label'.
     model_dir: The directory of a trained inception model. Can be local or GCS paths.
-    input_csv: The input csv which include two columns only: image_gs_url, label.
-        Can be local or GCS paths.
-    output_file: The output csv file containing prediction results.
-    output_bq_table: If provided, will also save the results to BigQuery table.
+    output_csv: The output csv file for prediction results. If specified,
+        it will also output a csv schema file with the name output_csv + '.schema.json'.
+    output_bq_table: if specified, the output BigQuery table for prediction results.
+        output_csv and output_bq_table can both be set.
+  Raises:
+    ValueError if both output_csv and output_bq_table are None.
   """
+
+  if output_csv is None and output_bq_table is None:
+    raise ValueError('output_csv and output_bq_table cannot both be None.')
+
   print('Predicting...')
-  _local.Local().batch_predict(model_dir, input_csv, output_file, output_bq_table)
+  _local.Local().batch_predict(dataset, model_dir, output_csv, output_bq_table)
   print('Done')
 
-def cloud_batch_predict(model_dir, image_files, show_image=True, output_file=None):
-  """Not Implemented Yet"""
-  pass
+
+def cloud_batch_predict(dataset, model_dir, gcs_staging_location,
+                        output_csv=None, output_bq_table=None, pipeline_option=None):
+  """Batch predict running in cloud.
+
+  Args:
+    dataset: CsvDataSet or BigQueryDataSet for batch prediction input. Can contain either
+        one column 'image_url', or two columns with another being 'label'.
+    model_dir: A GCS path to a trained inception model directory.
+    gcs_staging_location: A temporary location for DataFlow staging.
+    output_csv: If specified, prediction results will be saved to the specified Csv file.
+        It will also output a csv schema file with the name output_csv + '.schema.json'.
+        GCS file path only.
+    output_bq_table: If specified, prediction results will be saved to the specified BigQuery
+        table. output_csv and output_bq_table can both be set, but cannot be both None.
+    pipeline_option: DataFlow pipeline options in a dictionary.
+  Raises:
+    ValueError if both output_csv and output_bq_table are None.
+  """
+
+  if output_csv is None and output_bq_table is None:
+    raise ValueError('output_csv and output_bq_table cannot both be None.')
+
+  job_name = _cloud.Cloud().batch_predict(dataset, model_dir,
+                                          gcs_staging_location, output_csv,
                                          output_bq_table, pipeline_option)
+  if (_util.is_in_IPython()):
+    import IPython
+
+    dataflow_url = ('https://console.developers.google.com/dataflow?project=%s' %
+                    _util.default_project())
+    html = 'Job "%s" submitted.' % job_name
+    html += ('<p>Click <a href="%s" target="_blank">here</a> to track batch prediction job. <br/>'
+             % dataflow_url)
+    IPython.display.display_html(html, raw=True)
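
Putting the new public entry points in _package.py together, a notebook-level flow might look roughly like this (bucket, table, and directory names are hypothetical, training and other intermediate steps are omitted, and the import path is an assumption):

import datalab.mlalpha as mlalpha
from datalab_solutions.inception import _package as inception   # assumed import path

train_set = mlalpha.CsvDataSet('gs://my-bucket/train.csv',
                               schema='image_url:STRING,label:STRING')
eval_set = mlalpha.CsvDataSet('gs://my-bucket/eval.csv',
                              schema='image_url:STRING,label:STRING')

# Explicit eval data: the training set is used in full, no 7:3 random split.
inception.cloud_preprocess(train_set, 'gs://my-bucket/preprocessed',
                           eval_dataset=eval_set)

# ... train a model, then batch predict to both a CSV file and a BigQuery table.
predict_set = mlalpha.BigQueryDataSet(table='mydataset:images_to_predict')
inception.cloud_batch_predict(predict_set,
                              'gs://my-bucket/inception/model',
                              gcs_staging_location='gs://my-bucket/staging',
                              output_csv='gs://my-bucket/predict_results.csv',
                              output_bq_table='mydataset:image_predictions')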
