@@ -20,5 +20,5 @@
 import apache_beam as beam
-from apache_beam.utils.options import PipelineOptions
+from apache_beam.utils.pipeline_options import PipelineOptions
 import cStringIO
 import csv
 import google.cloud.ml as ml
@@ -35,8 +35,7 @@
 slim = tf.contrib.slim

 error_count = beam.Aggregator('errorCount')
-csv_rows_count = beam.Aggregator('csvRowsCount')
-labels_count = beam.Aggregator('labelsCount')
+rows_count = beam.Aggregator('RowsCount')
 skipped_empty_line = beam.Aggregator('skippedEmptyLine')
 embedding_good = beam.Aggregator('embedding_good')
 embedding_bad = beam.Aggregator('embedding_bad')
@@ -66,27 +65,24 @@ def process(self, context, all_labels):
           self.label_to_id_map[label] = i

     # Row format is:
-    # image_uri(,label_ids)*
-    row = context.element
-    if not row:
+    # image_uri,label_id
+    element = context.element
+    if not element:
       context.aggregate_to(skipped_empty_line, 1)
       return

-    context.aggregate_to(csv_rows_count, 1)
-    uri = row[0]
+    context.aggregate_to(rows_count, 1)
+    uri = element['image_url']
     if not uri or not uri.startswith('gs://'):
       context.aggregate_to(invalid_uri, 1)
       return

-    # In a real-world system, you may want to provide a default id for labels
-    # that were not in the dictionary. In this sample, we will throw an error.
-    # This code already supports multi-label problems if you want to use it.
-    label_ids = [self.label_to_id_map[label.strip()] for label in row[1:]]
-    context.aggregate_to(labels_count, len(label_ids))
-
-    if not label_ids:
+    try:
+      label_id = self.label_to_id_map[element['label'].strip()]
+    except KeyError:
       context.aggregate_to(ignored_unlabeled_image, 1)
-    yield row[0], label_ids
+      return
+    yield uri, label_id


 class ReadImageAndConvertToJpegDoFn(beam.DoFn):
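For reference, a minimal sketch (not part of the commit; input rows are hypothetical) of the dict-shaped elements the reworked `ExtractLabelIdsDoFn` now consumes, produced by the `csv.DictReader` parsing added in `_get_sources_from_csvs` further down:

```python
import csv

# Hypothetical rows in the new 'image_uri,label_id' format.
lines = ['gs://bucket/roses/001.jpg,roses',
         'gs://bucket/tulips/002.jpg,tulips']

for line in lines:
  # One DictReader per line, mirroring the pipeline's parsing step.
  element = csv.DictReader([line], fieldnames=['image_url', 'label']).next()
  print element['image_url'], element['label']
# gs://bucket/roses/001.jpg roses
# gs://bucket/tulips/002.jpg tulips
```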
@@ -97,7 +92,7 @@ class ReadImageAndConvertToJpegDoFn(beam.DoFn):
   """

   def process(self, context):
-    uri, label_ids = context.element
+    uri, label_id = context.element

     try:
       with ml.util._file.open_local_or_gcs(uri, mode='r') as f:
@@ -114,7 +109,7 @@ def process(self, context):
     output = cStringIO.StringIO()
     img.save(output, 'jpeg')
     image_bytes = output.getvalue()
-    yield uri, label_ids, image_bytes
+    yield uri, label_id, image_bytes


 class EmbeddingsGraph(object):
@@ -250,7 +245,7 @@ def _bytes_feature(value):
     def _float_feature(value):
       return tf.train.Feature(float_list=tf.train.FloatList(value=value))

-    uri, label_ids, image_bytes = context.element
+    uri, label_id, image_bytes = context.element

     try:
       embedding = self.preprocess_graph.calculate_embedding(image_bytes)
@@ -265,13 +260,11 @@ def _float_feature(value):
       context.aggregate_to(embedding_bad, 1)

     example = tf.train.Example(features=tf.train.Features(feature={
-        'image_uri': _bytes_feature([uri]),
+        'image_uri': _bytes_feature([str(uri)]),
         'embedding': _float_feature(embedding.ravel().tolist()),
     }))

-    if label_ids:
-      label_ids.sort()
-      example.features.feature['label'].int64_list.value.extend(label_ids)
+    example.features.feature['label'].int64_list.value.append(label_id)

     yield example

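A standalone sketch (values hypothetical, not part of the commit) of what the single-label change encodes: `append(label_id)` leaves exactly one value in the Example's `int64_list`, where the old multi-label path used `extend(label_ids)`:

```python
import tensorflow as tf

label_id = 3  # hypothetical id from label_to_id_map
example = tf.train.Example()
example.features.feature['label'].int64_list.value.append(label_id)
print example.features.feature['label'].int64_list.value  # [3]
```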
@@ -283,20 +276,31 @@ def partition_for(self, context, num_partitions):
     return 1 if random.random() > 0.7 else 0


-def configure_pipeline(p, checkpoint_path, input_paths, output_dir, job_id):
-  """Specify PCollection and transformations in pipeline."""
-  output_latest_file = os.path.join(output_dir, 'latest')
+def _get_sources_from_csvs(p, input_paths):
   source_list = []
   for ii, input_path in enumerate(input_paths):
-    input_source = beam.io.TextFileSource(input_path, strip_trailing_newlines=True)
-    source_list.append(p | 'Read input %d' % ii >> beam.Read(input_source))
-  all_sources = source_list | 'Flatten Sources' >> beam.Flatten()
-  labels = (all_sources
-            | 'Parse input for labels' >> beam.Map(lambda line: csv.reader([line]).next()[1])
+    source_list.append(p | 'Read from Csv %d' % ii >>
+                       beam.io.ReadFromText(input_path, strip_trailing_newlines=True))
+  all_sources = (source_list | 'Flatten Sources' >> beam.Flatten()
+                 | beam.Map(lambda line: csv.DictReader([line], fieldnames=['image_url', 'label']).next()))
+  return all_sources
+
+
+def _get_sources_from_bigquery(p, query):
+  if len(query.split()) == 1:
+    bq_source = beam.io.BigQuerySource(table=query)
+  else:
+    bq_source = beam.io.BigQuerySource(query=query)
+  query_results = p | 'Read from BigQuery' >> beam.io.Read(bq_source)
+  return query_results
+
+
+def _configure_pipeline_from_source(source, checkpoint_path, output_dir, job_id):
+  labels = (source
+            | 'Parse input for labels' >> beam.Map(lambda x: x['label'])
             | 'Combine labels' >> beam.transforms.combiners.Count.PerElement()
             | 'Get labels' >> beam.Map(lambda label_count: label_count[0]))
-  all_preprocessed = (all_sources
-                      | 'Parse input' >> beam.Map(lambda line: csv.reader([line]).next())
+  all_preprocessed = (source
                       | 'Extract label ids' >> beam.ParDo(ExtractLabelIdsDoFn(),
                                                           beam.pvalue.AsIter(labels))
                       | 'Read and convert to JPEG' >> beam.ParDo(ReadImageAndConvertToJpegDoFn())
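How `_get_sources_from_bigquery` dispatches, sketched with hypothetical table and query strings: a single token is treated as a table reference, anything longer as SQL. Either way, `BigQuerySource` emits one dict per row keyed by column name, which is why the downstream DoFns can index `element['image_url']` and `element['label']` for both CSV and BigQuery inputs.

```python
# Single token -> read the whole table (hypothetical table name).
rows = _get_sources_from_bigquery(p, 'my_dataset.images')

# Multiple tokens -> run it as a query (hypothetical query).
rows = _get_sources_from_bigquery(
    p, 'SELECT image_url, label FROM [my_dataset.images]')
```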
@@ -311,8 +315,19 @@ def configure_pipeline(p, checkpoint_path, input_paths, output_dir, job_id):
   eval_save = train_eval[1] | 'Save eval to disk' >> SaveFeatures(preprocessed_eval)
   train_save = train_eval[0] | 'Save train to disk' >> SaveFeatures(preprocessed_train)
   # Make sure we write "latest" file after train and eval data are successfully written.
+  output_latest_file = os.path.join(output_dir, 'latest')
   ([eval_save, train_save, labels_save] | 'Wait for train eval saving' >> beam.Flatten() |
    beam.transforms.combiners.Sample.FixedSizeGlobally('Fixed One', 1) |
    beam.Map(lambda path: job_id) |
    'WriteLatest' >> beam.io.textio.WriteToText(output_latest_file, shard_name_template=''))

+
+def configure_pipeline_csv(p, checkpoint_path, input_paths, output_dir, job_id):
+  all_sources = _get_sources_from_csvs(p, input_paths)
+  _configure_pipeline_from_source(all_sources, checkpoint_path, output_dir, job_id)
+
+
+def configure_pipeline_bigquery(p, checkpoint_path, query, output_dir, job_id):
+  all_sources = _get_sources_from_bigquery(p, query)
+  _configure_pipeline_from_source(all_sources, checkpoint_path, output_dir, job_id)
+
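A minimal sketch of driving the two new entry points (bucket paths, query string, and job id are hypothetical):

```python
import apache_beam as beam

opts = PipelineOptions()  # configure --runner, --project, etc. as usual

# Preprocess from CSV files:
p = beam.Pipeline(options=opts)
configure_pipeline_csv(p, 'gs://my-bucket/inception_checkpoint',
                       ['gs://my-bucket/all_data.csv'],
                       'gs://my-bucket/output', 'job-0')
p.run()

# Preprocess from a BigQuery query:
p = beam.Pipeline(options=opts)
configure_pipeline_bigquery(p, 'gs://my-bucket/inception_checkpoint',
                            'SELECT image_url, label FROM [dataset.images]',
                            'gs://my-bucket/output', 'job-0')
p.run()
```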