@@ -76,6 +76,7 @@ def parse_arguments(argv):
                       action='store_false',
                       help='Don\'t shard files')
   parser.set_defaults(shard_files=True)
+
   parser.add_argument('--output_format',
                       choices=['csv', 'json'],
                       default='csv',
@@ -104,55 +105,6 @@ def parse_arguments(argv):
   return args
 
 
-class FixMissingTarget(beam.DoFn):
-  """A DoFn to fix missing target columns."""
-
-  def __init__(self, trained_model_dir):
-    """Reads the schema file and extracted the expected number of columns.
-
-    Args:
-      trained_model_dir: path to model.
-
-    Raises:
-      ValueError: if schema.json not found in trained_model_dir
-    """
-    from tensorflow.python.lib.io import file_io
-    import json
-    import os
-
-    schema_path = os.path.join(trained_model_dir, 'schema.json')
-    if not file_io.file_exists(schema_path):
-      raise ValueError('schema.json missing from %s' % schema_path)
-    schema = json.loads(file_io.read_file_to_string(schema_path))
-    self._num_expected_columns = len(schema)
-
-  def process(self, element):
-    """Fixes csv line if target is missing.
-
-    The first column is assumed to be the target column, and the TF graph
-    expects to always parse the target column, even in prediction. Below,
-    we check how many csv columns there are, and if the target is missing, we
-    prepend a ',' to denote the missing column.
-
-    Example:
-      'target,key,value1,...' -> 'target,key,value1,...' (no change)
-      'key,value1,...' -> ',key,value1,...' (add missing target column)
-
-    The value of the missing target column comes from the default value given
-    to tf.decode_csv in the graph.
-    """
-    import apache_beam as beam
-
-    num_columns = len(element.split(','))
-    if num_columns == self._num_expected_columns:
-      yield element
-    elif num_columns + 1 == self._num_expected_columns:
-      yield ',' + element
-    else:
-      yield beam.pvalue.SideOutputValue('errors',
-                                        ('bad columns', element))
-
-
 class EmitAsBatchDoFn(beam.DoFn):
   """A DoFn that buffers the records and emits them batch by batch."""
 
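For context, the FixMissingTarget DoFn deleted above only padded prediction-time rows that were short by exactly one column; anything else was routed to the 'errors' side output. A standalone sketch of that rule (using a hypothetical three-column schema, outside of Beam) behaves like this:

num_expected_columns = 3  # hypothetical schema: target, key, value

def fix_missing_target(element):
  # Mirrors the removed DoFn: pass complete rows through, prepend an empty
  # target column when exactly one column is missing, flag anything else.
  num_columns = len(element.split(','))
  if num_columns == num_expected_columns:
    return element
  elif num_columns + 1 == num_expected_columns:
    return ',' + element
  return None  # the DoFn emitted these rows to the 'errors' side output

print(fix_missing_target('1.5,k1,2.0'))  # '1.5,k1,2.0' (unchanged)
print(fix_missing_target('k1,2.0'))      # ',k1,2.0' (empty target prepended)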
@@ -185,22 +137,22 @@ def __init__(self, trained_model_dir):
     self._session = None
 
   def start_bundle(self, element=None):
-    from tensorflow.contrib.session_bundle import session_bundle
+    from tensorflow.python.saved_model import tag_constants
+    from tensorflow.contrib.session_bundle import bundle_shim
     import json
 
-    self._session, _ = session_bundle.load_session_bundle_from_path(
-        self._trained_model_dir)
-
-    # input_alias_map {'input_csv_string': tensor_name}
-    self._input_alias_map = json.loads(
-        self._session.graph.get_collection('inputs')[0])
-
-    # output_alias_map {'target_from_input': tensor_name, 'key': ...}
-    self._output_alias_map = json.loads(
-        self._session.graph.get_collection('outputs')[0])
+    self._session, meta_graph = bundle_shim.load_session_bundle_or_saved_model_bundle_from_path(self._trained_model_dir, tags=[tag_constants.SERVING])
+    signature = meta_graph.signature_def['serving_default']
 
+    # get the mappings between aliases and tensor names
+    # for both inputs and outputs
+    self._input_alias_map = {friendly_name: tensor_info_proto.name
+                             for (friendly_name, tensor_info_proto) in signature.inputs.items()}
+    self._output_alias_map = {friendly_name: tensor_info_proto.name
+                              for (friendly_name, tensor_info_proto) in signature.outputs.items()}
     self._aliases, self._tensor_names = zip(*self._output_alias_map.items())
 
+
   def finish_bundle(self, element=None):
     self._session.close()
 
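The rewritten start_bundle resolves friendly input and output names to tensor names through the SavedModel's 'serving_default' signature instead of the old 'inputs'/'outputs' graph collections. As a rough sketch of how those maps are consumed downstream (the helper below is illustrative only, not code from this patch), a batch of raw CSV strings is fed to the single input tensor and every output tensor is fetched in one session.run call:

def run_batch(session, input_alias_map, output_alias_map, csv_lines):
  # Assumes exactly one input tensor, as this prediction graph expects.
  input_tensor_name = list(input_alias_map.values())[0]
  aliases, tensor_names = zip(*output_alias_map.items())
  results = session.run(fetches=list(tensor_names),
                        feed_dict={input_tensor_name: csv_lines})
  # Map each output alias back to its batch of values.
  return dict(zip(aliases, results))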
@@ -220,6 +172,11 @@ def process(self, element):
 
     feed_dict = collections.defaultdict(list)
     for line in element:
+
+      # Remove trailing newline.
+      if line.endswith('\n'):
+        line = line[:-1]
+
       feed_dict[self._input_alias_map.values()[0]].append(line)
       num_in_batch += 1
 
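The added guard strips at most one trailing newline from each raw line before it joins the feed batch. A shorter spelling, noted only as an aside and not part of the patch, would be:

      # line = line.rstrip('\n')  # drops every trailing newline, not just one

The endswith/slice pair used in the patch removes exactly one character, which keeps the change minimal.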
@@ -311,26 +268,41 @@ def __init__(self, args):
     self._output_format = args.output_format
     self._output_dir = args.output_dir
 
-    # See if the target vocab should be loaded.
+    # Get the BQ schema if csv.
     if self._output_format == 'csv':
-      from tensorflow.contrib.session_bundle import session_bundle
-      import json
-
-      self._session, _ = session_bundle.load_session_bundle_from_path(
-          args.trained_model_dir)
-
-      # output_alias_map {'target_from_input': tensor_name, 'key': ...}
-      output_alias_map = json.loads(
-          self._session.graph.get_collection('outputs')[0])
-
-      self._header = sorted(output_alias_map.keys())
-      self._session.close()
-
+      from tensorflow.python.saved_model import tag_constants
+      from tensorflow.contrib.session_bundle import bundle_shim
+      from tensorflow.core.framework import types_pb2
+
+      session, meta_graph = bundle_shim.load_session_bundle_or_saved_model_bundle_from_path(args.trained_model_dir, tags=[tag_constants.SERVING])
+      signature = meta_graph.signature_def['serving_default']
+
+      self._schema = []
+      for friendly_name in sorted(signature.outputs):
+        tensor_info_proto = signature.outputs[friendly_name]
+
+        # TODO(brandondutra): Could dtype be DT_INVALID?
+        # Consider getting the dtype from the graph via
+        # session.graph.get_tensor_by_name(tensor_info_proto.name).dtype
+        dtype = tensor_info_proto.dtype
+        if dtype == types_pb2.DT_FLOAT or dtype == types_pb2.DT_DOUBLE:
+          bq_type = 'FLOAT'
+        elif dtype == types_pb2.DT_INT32 or dtype == types_pb2.DT_INT64:
+          bq_type = 'INTEGER'
+        else:
+          bq_type = 'STRING'
+
+        self._schema.append({'mode': 'NULLABLE',
+                             'name': friendly_name,
+                             'type': bq_type})
+      session.close()
 
   def apply(self, datasets):
     return self.expand(datasets)
 
   def expand(self, datasets):
+    import json
+
     tf_graph_predictions, errors = datasets
 
     if self._output_format == 'json':
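The new constructor walks the signature's outputs in sorted order and maps each TensorFlow dtype onto a BigQuery-style type, falling back to STRING for anything that is not a float or integer type. For a model whose serving signature exposed, say, an INT64 'key', a STRING 'predicted' label, and a FLOAT 'score' (output names assumed here purely for illustration), self._schema would come out as:

self._schema = [
    {'mode': 'NULLABLE', 'name': 'key', 'type': 'INTEGER'},
    {'mode': 'NULLABLE', 'name': 'predicted', 'type': 'STRING'},
    {'mode': 'NULLABLE', 'name': 'score', 'type': 'FLOAT'},
]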
@@ -344,15 +316,16 @@ def expand(self, datasets):
               shard_name_template=self._shard_name_template))
     elif self._output_format == 'csv':
       # make a csv header file
-      csv_coder = CSVCoder(self._header)
+      header = [col['name'] for col in self._schema]
+      csv_coder = CSVCoder(header)
       _ = (
           tf_graph_predictions.pipeline
           | 'Make CSV Header'
-          >> beam.Create([csv_coder.make_header_string()])
-          | 'Write CSV Header File'
+          >> beam.Create([json.dumps(self._schema, indent=2)])
+          | 'Write CSV Schema File'
           >> beam.io.textio.WriteToText(
               os.path.join(self._output_dir, 'csv_header'),
-              file_name_suffix='.txt',
+              file_name_suffix='.json',
               shard_name_template=''))
 
       # Write the csv predictions
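With this change the csv_header file carries the full JSON schema rather than a single header string, so a downstream step can recover both the column order and the types. A minimal sketch of reading it back (the file name follows from the WriteToText call above; path handling is simplified):

import json

with open('csv_header.json') as f:
  schema = json.load(f)
header = [col['name'] for col in schema]  # same order CSVCoder uses above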
@@ -387,15 +360,11 @@ def make_prediction_pipeline(pipeline, args):
     pipeline: the pipeline
     args: command line args
   """
-
-
   predicted_values, errors = (
       pipeline
       | 'Read CSV Files'
       >> beam.io.ReadFromText(args.predict_data,
                               strip_trailing_newlines=True)
-      | 'Is Target Missing'
-      >> beam.ParDo(FixMissingTarget(args.trained_model_dir))
       | 'Batch Input'
       >> beam.ParDo(EmitAsBatchDoFn(args.batch_size))
       | 'Run TF Graph on Batches'