
Commit bfff3bc

sd training supports python 3 (#320)
* sw
* added e2e test
* raise logging level
* test only python 2. Python 3 support for SD is close I think, but I'll continue splitting this into smaller PRs.
* flake8
* add dataflow to test build, but not setup.py
* install dataflow in python 2.7 test only
* remove a print
* added local analyze test to python3
* sw
* sd training supports python 3
* flake8
* flake
* io.StringIO -> StringIO.StringIO for python 2
* use my wrapper
1 parent 7fe89c2 commit bfff3bc

7 files changed (+62 / -29 lines changed)

solutionbox/structured_data/mltoolbox/_structured_data/_package.py

Lines changed: 6 additions & 0 deletions
@@ -35,6 +35,7 @@
 import tempfile
 import json
 import glob
+import six
 import subprocess
 import pandas as pd
 from tensorflow.python.lib.io import file_io
@@ -430,8 +431,13 @@ def _get_abs_path(input_path):

     while p.poll() is None:
       line = p.stdout.readline()
+
+      if not six.PY2:
+        line = line.decode()
+
       if (line.startswith('INFO:tensorflow:global') or line.startswith('INFO:tensorflow:loss') or
           line.startswith('INFO:tensorflow:Saving dict')):
+
         sys.stdout.write(line)
   finally:
     if monitor_process:
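
Why the decode matters: on Python 3, pipes created by subprocess yield bytes, so the startswith checks against str literals would never match without it. A minimal sketch of the pattern, assuming a toy command in place of the real training subprocess:

    import subprocess
    import sys

    import six

    p = subprocess.Popen(
        [sys.executable, '-c', "print('INFO:tensorflow:loss = 0.5')"],
        stdout=subprocess.PIPE)
    for raw in iter(p.stdout.readline, b''):  # b'' == '' on Python 2
      line = raw if six.PY2 else raw.decode()  # pipes yield bytes on Python 3
      if line.startswith('INFO:tensorflow:'):
        sys.stdout.write(line)
    p.wait()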

solutionbox/structured_data/mltoolbox/_structured_data/preprocess/cloud_preprocess.py

Lines changed: 2 additions & 2 deletions
@@ -17,9 +17,9 @@
 from __future__ import print_function

 import argparse
-import io
 import json
 import os
+import six
 import sys


@@ -212,7 +212,7 @@ def run_categorical_analysis(table, schema_list, args):
   df = query.execute().result().to_dataframe()

   # Write the results to a file.
-  string_buff = io.StringIO()
+  string_buff = six.StringIO()
   df.to_csv(string_buff, index=False, header=False)
   file_io.write_string_to_file(out_file, string_buff.getvalue())
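
The swap to six.StringIO is the Python 2 accommodation: io.StringIO accepts only unicode, while pandas' to_csv writes byte strings on Python 2. six.StringIO resolves to StringIO.StringIO on Python 2 and io.StringIO on Python 3, so one spelling works on both. A small sketch, assuming a toy DataFrame:

    import pandas as pd
    import six

    df = pd.DataFrame({'value': [1, 2]})
    string_buff = six.StringIO()  # StringIO.StringIO on PY2, io.StringIO on PY3
    df.to_csv(string_buff, index=False, header=False)
    print(string_buff.getvalue())  # '1\n2\n'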

solutionbox/structured_data/mltoolbox/_structured_data/trainer/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-import task
+from __future__ import absolute_import
+
+from . import task

 __all__ = ['task']
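
The old "import task" is an implicit relative import: Python 2 finds task.py in the package directory, but Python 3 treats every bare import as absolute and raises ImportError. Adding the __future__ import gives Python 2 the same semantics, so the explicit relative form is the one spelling that works on both. Illustrated against this package's own layout:

    # trainer/__init__.py
    from __future__ import absolute_import

    from . import task   # explicit relative import: valid on Python 2 and 3
    # import task        # implicit relative import: ImportError on Python 3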

solutionbox/structured_data/mltoolbox/_structured_data/trainer/util.py

Lines changed: 40 additions & 18 deletions
@@ -17,6 +17,7 @@
 import multiprocessing
 import os
 import math
+import six

 import tensorflow as tf
 from tensorflow.python.lib.io import file_io
@@ -65,21 +66,16 @@ class NotFittedError(ValueError):
 # ==============================================================================


-def _copy_all(src_files, dest_dir):
-  # file_io.copy does not copy files into folders directly.
-  for src_file in src_files:
-    file_name = os.path.basename(src_file)
-    new_file_location = os.path.join(dest_dir, file_name)
-    file_io.copy(src_file, new_file_location, overwrite=True)
-
-
 def _recursive_copy(src_dir, dest_dir):
   """Copy the contents of src_dir into the folder dest_dir.
   Args:
     src_dir: gsc or local path.
     dest_dir: gcs or local path.
   When called, dest_dir should exist.
   """
+  src_dir = python_portable_string(src_dir)
+  dest_dir = python_portable_string(dest_dir)
+
   file_io.recursive_create_dir(dest_dir)
   for file_name in file_io.list_directory(src_dir):
     old_path = os.path.join(src_dir, file_name)
@@ -252,7 +248,9 @@ def export_fn(estimator, export_dir_base, checkpoint_path=None, eval_result=None
       gfile.Copy(source, dest_absolute)

     # only keep the last 3 models
-    saved_model_export_utils.garbage_collect_exports(export_dir_base, exports_to_keep=3)
+    saved_model_export_utils.garbage_collect_exports(
+        python_portable_string(export_dir_base),
+        exports_to_keep=3)

     # save the last model to the model folder.
     # export_dir_base = A/B/intermediate_models/
@@ -482,7 +480,8 @@ def preprocess_input(features, target, train_config, preprocess_output_dir,
                      (NUMERICAL_ANALYSIS, preprocess_output_dir))

   numerical_anlysis = json.loads(
-      file_io.read_file_to_string(numerical_analysis_file))
+      python_portable_string(
+          file_io.read_file_to_string(numerical_analysis_file)))

   for name in train_config['numerical_columns']:
     if name == target_name or name == key_name:
@@ -671,7 +670,8 @@ def get_vocabulary(preprocess_output_dir, name):
     raise ValueError('File %s not found in %s' %
                      (CATEGORICAL_ANALYSIS % name, preprocess_output_dir))

-  labels = file_io.read_file_to_string(vocab_file).split('\n')
+  labels = python_portable_string(
+      file_io.read_file_to_string(vocab_file)).split('\n')
   label_values = [x for x in labels if x]  # remove empty lines

   return label_values
@@ -709,10 +709,13 @@ def merge_metadata(preprocess_output_dir, transforms_file):
                                        NUMERICAL_ANALYSIS)
   schema_file = os.path.join(preprocess_output_dir, SCHEMA_FILE)

-  numerical_anlysis = json.loads(file_io.read_file_to_string(
-      numerical_anlysis_file))
-  schema = json.loads(file_io.read_file_to_string(schema_file))
-  transforms = json.loads(file_io.read_file_to_string(transforms_file))
+  numerical_anlysis = json.loads(
+      python_portable_string(
+          file_io.read_file_to_string(numerical_anlysis_file)))
+  schema = json.loads(
+      python_portable_string(file_io.read_file_to_string(schema_file)))
+  transforms = json.loads(
+      python_portable_string(file_io.read_file_to_string(transforms_file)))

   result_dict = {}
   result_dict['csv_header'] = [col_schema['name'] for col_schema in schema]
@@ -725,7 +728,7 @@ def merge_metadata(preprocess_output_dir, transforms_file):
   result_dict['vocab_stats'] = {}

   # get key column.
-  for name, trans_config in transforms.iteritems():
+  for name, trans_config in six.iteritems(transforms):
     if trans_config.get('transform', None) == 'key':
       result_dict['key_column'] = name
       break
@@ -734,7 +737,7 @@ def merge_metadata(preprocess_output_dir, transforms_file):

   # get target column.
   result_dict['target_column'] = schema[0]['name']
-  for name, trans_config in transforms.iteritems():
+  for name, trans_config in six.iteritems(transforms):
     if trans_config.get('transform', None) == 'target':
       result_dict['target_column'] = name
       break
@@ -756,7 +759,7 @@ def merge_metadata(preprocess_output_dir, transforms_file):
       raise ValueError('Unsupported schema type %s' % col_type)

   # Get the transforms.
-  for name, trans_config in transforms.iteritems():
+  for name, trans_config in six.iteritems(transforms):
     if name != result_dict['target_column'] and name != result_dict['key_column']:
       result_dict['transforms'][name] = trans_config

@@ -849,3 +852,22 @@ def is_regression_model(model_type):

 def is_classification_model(model_type):
   return model_type.endswith('_classification')
+
+
+# Note that this function exists in google.datalab.utils, but that is not
+# installed on the training workers.
+def python_portable_string(string, encoding='utf-8'):
+  """Converts bytes into a string type.
+
+  Valid string types are retuned without modification. So in Python 2, type str
+  and unicode are not converted.
+
+  In Python 3, type bytes is converted to type str (unicode)
+  """
+  if isinstance(string, six.string_types):
+    return string
+
+  if six.PY3:
+    return string.decode(encoding)
+
+  raise ValueError('Unsupported type %s' % str(type(string)))
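
Two portability patterns run through this file. python_portable_string handles readers that may return bytes under Python 3 (file_io.read_file_to_string among them), since json.loads rejects bytes on Python 3 versions before 3.6; six.iteritems replaces dict.iteritems(), which Python 3 removed. A usage sketch, with the helper repeated for self-containment and a toy JSON payload in place of real preprocessing output:

    import json

    import six

    def python_portable_string(string, encoding='utf-8'):
      # Pass str/unicode through unchanged; decode bytes on Python 3.
      if isinstance(string, six.string_types):
        return string
      if six.PY3:
        return string.decode(encoding)
      raise ValueError('Unsupported type %s' % str(type(string)))

    # Bytes from a file reader parse identically on Python 2 and 3:
    schema = json.loads(python_portable_string(b'[{"name": "id"}]'))

    transforms = {'id': {'transform': 'key'}}
    for name, trans_config in six.iteritems(transforms):  # iteritems() is PY2-only
      print(name, trans_config)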

solutionbox/structured_data/test_mltoolbox/e2e_functions.py

Lines changed: 5 additions & 0 deletions
@@ -17,6 +17,7 @@
 import os
 import random
 import json
+import six
 import subprocess


@@ -202,6 +203,10 @@ def run_training(
   logger.debug('Going to run command: %s' % ' '.join(cmd))
   sp = subprocess.Popen(' '.join(cmd), shell=True, stderr=subprocess.PIPE)
   _, err = sp.communicate()
+
+  if not six.PY2:
+    err = err.decode()
+
   return err

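
communicate() has the same bytes behavior as readline(): on Python 3 the captured stderr comes back as bytes and must be decoded before callers can do string work on it. (universal_newlines=True would push the decoding into subprocess itself, but the commit keeps the explicit decode.) A minimal sketch with a stand-in command:

    import subprocess
    import sys

    import six

    sp = subprocess.Popen(
        [sys.executable, '-c', "import sys; sys.stderr.write('warning\\n')"],
        stderr=subprocess.PIPE)
    _, err = sp.communicate()
    if not six.PY2:
      err = err.decode()  # communicate() returns bytes on Python 3
    assert isinstance(err, str)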

solutionbox/structured_data/test_mltoolbox/test_datalab_e2e.py

Lines changed: 3 additions & 1 deletion
@@ -178,8 +178,10 @@ def test_e2e(self):
     try:
       self._make_test_files()
       self._run_analyze()
+      self._run_train()
       if six.PY2:
-        self._run_train()
+        # Dataflow is only supported by python 2. Prediction assumes Dataflow
+        # is installed.
         self._run_predict()
         self._run_batch_prediction(
             os.path.join(self._batch_predict_output, 'with_target'),

tests/main.py

Lines changed: 3 additions & 7 deletions
@@ -38,9 +38,9 @@
 import kernel.html_tests
 import kernel.storage_tests
 import kernel.utils_tests
-import mltoolbox_structured_data.traininglib_tests
-import mltoolbox_structured_data.sd_e2e_tests
 import mltoolbox_structured_data.dl_interface_tests
+import mltoolbox_structured_data.sd_e2e_tests
+import mltoolbox_structured_data.traininglib_tests
 import stackdriver.commands.monitoring_tests
 import stackdriver.monitoring.group_tests
 import stackdriver.monitoring.metric_tests
@@ -78,6 +78,7 @@
     kernel.utils_tests,
     mltoolbox_structured_data.dl_interface_tests,
     mltoolbox_structured_data.sd_e2e_tests,  # Not everything runs in Python 3.
+    mltoolbox_structured_data.traininglib_tests,
     stackdriver.commands.monitoring_tests,
     stackdriver.monitoring.group_tests,
     stackdriver.monitoring.metric_tests,
@@ -93,11 +94,6 @@
     _util.util_tests
 ]

-# mltoolbox is not part of the datalab install, but it should still be tested.
-# mltoolbox does not work with python 3.
-if sys.version_info.major == 2:
-  _TEST_MODULES.append(mltoolbox_structured_data.traininglib_tests)
-
 if __name__ == '__main__':
   suite = unittest.TestSuite()
   for m in _TEST_MODULES:
