@@ -17,6 +17,7 @@
import multiprocessing
import os
import math
+ import six

import tensorflow as tf
from tensorflow.python.lib.io import file_io
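The new `six` import is the compatibility shim the rest of this diff leans on. For orientation, a minimal sketch of the two helpers used below, `six.string_types` and `six.PY3`; the printed values are illustrative:

import six

# six.PY3 is True only under Python 3.
# six.string_types is (basestring,) on Python 2 and (str,) on Python 3,
# so a bytes value only fails the isinstance check on Python 3.
print(six.PY3)
print(isinstance('text', six.string_types))  # True on both interpreters
print(isinstance(b'raw', six.string_types))  # True on Python 2, False on Python 3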
@@ -65,21 +66,16 @@ class NotFittedError(ValueError):
# ==============================================================================


- def _copy_all(src_files, dest_dir):
-   # file_io.copy does not copy files into folders directly.
-   for src_file in src_files:
-     file_name = os.path.basename(src_file)
-     new_file_location = os.path.join(dest_dir, file_name)
-     file_io.copy(src_file, new_file_location, overwrite=True)
-
-
def _recursive_copy(src_dir, dest_dir):
  """Copy the contents of src_dir into the folder dest_dir.
  Args:
    src_dir: gcs or local path.
    dest_dir: gcs or local path.
  When called, dest_dir should exist.
  """
+   src_dir = python_portable_string(src_dir)
+   dest_dir = python_portable_string(dest_dir)
+
  file_io.recursive_create_dir(dest_dir)
  for file_name in file_io.list_directory(src_dir):
    old_path = os.path.join(src_dir, file_name)
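Since both arguments are now run through `python_portable_string`, `_recursive_copy` tolerates bytes as well as text paths. A hypothetical call (the bucket and directory names are made up):

# Either call would now behave the same; the bytes path is normalized to text first.
_recursive_copy('gs://my-bucket/preprocess_output', '/tmp/preprocess_copy')
_recursive_copy(b'gs://my-bucket/preprocess_output', '/tmp/preprocess_copy')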
@@ -252,7 +248,9 @@ def export_fn(estimator, export_dir_base, checkpoint_path=None, eval_result=None
    gfile.Copy(source, dest_absolute)

  # only keep the last 3 models
-   saved_model_export_utils.garbage_collect_exports(export_dir_base, exports_to_keep=3)
+   saved_model_export_utils.garbage_collect_exports(
+       python_portable_string(export_dir_base),
+       exports_to_keep=3)

  # save the last model to the model folder.
  # export_dir_base = A/B/intermediate_models/
@@ -482,7 +480,8 @@ def preprocess_input(features, target, train_config, preprocess_output_dir,
                     (NUMERICAL_ANALYSIS, preprocess_output_dir))

  numerical_anlysis = json.loads(
-       file_io.read_file_to_string(numerical_analysis_file))
+       python_portable_string(
+           file_io.read_file_to_string(numerical_analysis_file)))

  for name in train_config['numerical_columns']:
    if name == target_name or name == key_name:
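Wrapping the read matters because `file_io.read_file_to_string` can return `bytes` under Python 3, and `json.loads` only accepts `str` on Python 3.5 and earlier. A small illustration of the failure mode being avoided (the JSON payload is made up):

import json

raw = b'{"col1": {"min": 0.0, "max": 9.0}}'  # what a bytes read looks like
# json.loads(raw) raises TypeError on Python 3.5 and earlier.
stats = json.loads(python_portable_string(raw))
print(stats['col1']['max'])  # 9.0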
@@ -671,7 +670,8 @@ def get_vocabulary(preprocess_output_dir, name):
    raise ValueError('File %s not found in %s' %
                     (CATEGORICAL_ANALYSIS % name, preprocess_output_dir))

-   labels = file_io.read_file_to_string(vocab_file).split('\n')
+   labels = python_portable_string(
+       file_io.read_file_to_string(vocab_file)).split('\n')
  label_values = [x for x in labels if x]  # remove empty lines

  return label_values
@@ -709,10 +709,13 @@ def merge_metadata(preprocess_output_dir, transforms_file):
      NUMERICAL_ANALYSIS)
  schema_file = os.path.join(preprocess_output_dir, SCHEMA_FILE)

-   numerical_anlysis = json.loads(file_io.read_file_to_string(
-       numerical_anlysis_file))
-   schema = json.loads(file_io.read_file_to_string(schema_file))
-   transforms = json.loads(file_io.read_file_to_string(transforms_file))
+   numerical_anlysis = json.loads(
+       python_portable_string(
+           file_io.read_file_to_string(numerical_anlysis_file)))
+   schema = json.loads(
+       python_portable_string(file_io.read_file_to_string(schema_file)))
+   transforms = json.loads(
+       python_portable_string(file_io.read_file_to_string(transforms_file)))

  result_dict = {}
  result_dict['csv_header'] = [col_schema['name'] for col_schema in schema]
@@ -725,7 +728,7 @@ def merge_metadata(preprocess_output_dir, transforms_file):
  result_dict['vocab_stats'] = {}

  # get key column.
-   for name, trans_config in transforms.iteritems():
+   for name, trans_config in six.iteritems(transforms):
    if trans_config.get('transform', None) == 'key':
      result_dict['key_column'] = name
      break
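`dict.iteritems()` exists only in Python 2; Python 3 dicts have `items()` instead, so the old loop would raise `AttributeError` on Python 3 workers. `six.iteritems` dispatches to whichever method the interpreter provides. A quick check with made-up transform data:

import six

transforms = {'id': {'transform': 'key'}, 'income': {'transform': 'target'}}
for name, trans_config in six.iteritems(transforms):
  print(name, trans_config.get('transform', None))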
@@ -734,7 +737,7 @@ def merge_metadata(preprocess_output_dir, transforms_file):

  # get target column.
  result_dict['target_column'] = schema[0]['name']
-   for name, trans_config in transforms.iteritems():
+   for name, trans_config in six.iteritems(transforms):
    if trans_config.get('transform', None) == 'target':
      result_dict['target_column'] = name
      break
@@ -756,7 +759,7 @@ def merge_metadata(preprocess_output_dir, transforms_file):
    raise ValueError('Unsupported schema type %s' % col_type)

  # Get the transforms.
-   for name, trans_config in transforms.iteritems():
+   for name, trans_config in six.iteritems(transforms):
    if name != result_dict['target_column'] and name != result_dict['key_column']:
      result_dict['transforms'][name] = trans_config

@@ -849,3 +852,22 @@ def is_regression_model(model_type):

def is_classification_model(model_type):
  return model_type.endswith('_classification')
+
+
+ # Note that this function exists in google.datalab.utils, but that is not
+ # installed on the training workers.
+ def python_portable_string(string, encoding='utf-8'):
+   """Converts bytes into a string type.
+
+   Valid string types are returned without modification. So in Python 2, type str
+   and unicode are not converted.
+
+   In Python 3, type bytes is converted to type str (unicode).
+   """
+   if isinstance(string, six.string_types):
+     return string
+
+   if six.PY3:
+     return string.decode(encoding)
+
+   raise ValueError('Unsupported type %s' % str(type(string)))
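A quick exercise of the branches of `python_portable_string` (the values are illustrative). Note the `ValueError` branch is only reachable on Python 2; on Python 3 a non-string, non-bytes argument hits `.decode` first and raises `AttributeError` instead:

print(python_portable_string('already text'))     # any str is returned unchanged
print(python_portable_string(u'unicode text'))    # unicode (Python 2) is returned unchanged
print(python_portable_string(b'needs decoding'))  # decoded to 'needs decoding' on Python 3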