# limitations under the License.
# ==============================================================================

+
+ import os
import random
import subprocess
- import os
+

def make_csv_data(filename, num_rows, problem_type):
  random.seed(12321)
@@ -29,9 +31,9 @@ def make_csv_data(filename, num_rows, problem_type):
    str2 = random.choice(['abc', 'def', 'ghi', 'jkl', 'mno', 'pqr'])
    str3 = random.choice(['car', 'truck', 'van', 'bike', 'train', 'drone'])

-     map1 = {'red':2, 'blue':6, 'green':4, 'pink':-5, 'yellow':-6, 'brown':-1, 'black':7}
-     map2 = {'abc':10, 'def':1, 'ghi':1, 'jkl':1, 'mno':1, 'pqr':1}
-     map3 = {'car':5, 'truck':10, 'van':15, 'bike':20, 'train':25, 'drone': 30}
+     map1 = {'red': 2, 'blue': 6, 'green': 4, 'pink': -5, 'yellow': -6, 'brown': -1, 'black': 7}
+     map2 = {'abc': 10, 'def': 1, 'ghi': 1, 'jkl': 1, 'mno': 1, 'pqr': 1}
+     map3 = {'car': 5, 'truck': 10, 'van': 15, 'bike': 20, 'train': 25, 'drone': 30}

    # Build some model.
    t = 0.5 + 0.5 * num1 - 2.5 * num2 + num3
@@ -56,67 +58,83 @@ def make_csv_data(filename, num_rows, problem_type):
        str3=str3)
    f1.write(csv_line)

-   config = {'column_names': ['key', 'target', 'num1', 'num2', 'num3',
+   schema = {'column_names': ['key', 'target', 'num1', 'num2', 'num3',
                               'str1', 'str2', 'str3'],
             'key_column': 'key',
             'target_column': 'target',
-            'problem_type': problem_type,
-            'model_type': '',
-            'numerical': {'num1': {'transform': 'identity'},
-                          'num2': {'transform': 'identity'},
-                          'num3': {'transform': 'identity'}},
-            'categorical': {'str1': {'transform': 'one_hot'},
-                            'str2': {'transform': 'one_hot'},
-                            'str3': {'transform': 'one_hot'}}
+            'numerical_columns': ['num1', 'num2', 'num3'],
+            'categorical_columns': ['str1', 'str2', 'str3']
  }
-   return config
-
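+   # The target column is categorical for classification and numerical for regression.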
+   if problem_type == 'classification':
+     schema['categorical_columns'] += ['target']
+   else:
+     schema['numerical_columns'] += ['target']

+   # use defaults for num3 and str3
+   transforms = {'num1': {'transform': 'identity'},
+                 'num2': {'transform': 'identity'},
+                 # 'num3': {'transform': 'identity'},
+                 'str1': {'transform': 'one_hot'},
+                 'str2': {'transform': 'one_hot'},
+                 # 'str3': {'transform': 'one_hot'}
+                 }
+   return schema, transforms


- def run_preprocess(output_dir, csv_filename, config_filename,
+ def run_preprocess(output_dir, csv_filename, schema_filename,
                    train_percent='80', eval_percent='10', test_percent='10'):
-   cmd = ['python', './preprocess/preprocess.py',
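+   # Locate preprocess.py relative to this test file so the test does not depend on the working directory.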
+   preprocess_script = os.path.abspath(
+       os.path.join(os.path.dirname(__file__), '../preprocess/preprocess.py'))
+   cmd = ['python', preprocess_script,
         '--output_dir', output_dir,
-        '--input_file_path', csv_filename,
-        '--transforms_config_file', config_filename,
+        '--input_file_path', csv_filename,
+        '--schema_file', schema_filename,
         '--train_percent', train_percent,
         '--eval_percent', eval_percent,
         '--test_percent', test_percent,
        ]
-   print('Current working directoyr: %s' % os.getcwd())
  print('Going to run command: %s' % ' '.join(cmd))
  subprocess.check_call(cmd, stderr=open(os.devnull, 'wb'))

- def run_training(output_dir, input_dir, config_filename, extra_args=[]):
-   """Runs Training via gcloud alpha ml local train.
+
+ def run_training(output_dir, input_dir, schema_filename, transforms_filename,
+                  max_steps, extra_args=[]):
+   """Runs Training via gcloud beta ml local train.

  Args:
    output_dir: the trainer's output folder
-     input_folder: should contain features_train*, features_eval*, and
+     input_dir: should contain features_train*, features_eval*, and
      metadata.json.
-     config_filename: path to the config file
+     schema_filename: path to the schema file
+     transforms_filename: path to the transforms file.
+     max_steps: int. max training steps.
    extra_args: array of strings, passed to the trainer.
+
+   Returns:
+     The stderr of training as one string. TF writes to stderr, so basically, the
+     output of training.
  """
  train_filename = os.path.join(input_dir, 'features_train*')
  eval_filename = os.path.join(input_dir, 'features_eval*')
  metadata_filename = os.path.join(input_dir, 'metadata.json')
-   cmd = ['gcloud alpha ml local train',
+
+   # Gcloud has the fun bug that you have to be in the parent folder of task.py
+   # when you call it. So cd there first.
+   task_parent_folder = os.path.abspath(
+       os.path.join(os.path.dirname(__file__), '..'))
+   cmd = ['cd %s &&' % task_parent_folder,
+          'gcloud beta ml local train',
         '--module-name=trainer.task',
         '--package-path=trainer',
         '--',
         '--train_data_paths=%s' % train_filename,
         '--eval_data_paths=%s' % eval_filename,
         '--metadata_path=%s' % metadata_filename,
         '--output_path=%s' % output_dir,
-        '--transforms_config_file=%s' % config_filename,
-        '--max_steps=2500'] + extra_args
-   print('Current working directoyr: %s' % os.getcwd())
+        '--schema_file=%s' % schema_filename,
+        '--transforms_file=%s' % transforms_filename,
+        '--max_steps=%s' % max_steps] + extra_args
  print('Going to run command: %s' % ' '.join(cmd))
-   sp = subprocess.Popen(' '.join(cmd), shell=True, stderr=subprocess.PIPE)  # open(os.devnull, 'wb'))
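+   # shell=True so the shell runs the 'cd ... &&' prefix; stderr is captured and returned to the caller.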
+   sp = subprocess.Popen(' '.join(cmd), shell=True, stderr=subprocess.PIPE)
  _, err = sp.communicate()
-   err = err.splitlines()
-   print 'last line'
-   print err[len(err)-1]
-
-   stderr = subprocess.PIPE
+   return err