Skip to content

Commit 4ea9d25

Browse files
committed
Changed cross_validation to model_selection due to deprecation
1 parent 9230bb1 commit 4ea9d25

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

net_mk1.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# Copyright 2016 The TF Codelab Contributors. All Rights Reserved.
22
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
33
# http://www.apache.org/licenses/LICENSE-2.0
4-
#
4+
#
55
# This code was originaflly presented at GDGSpain DevFest
66
# using character prediction from Tensorflow
77
# https://github.com/bigpress/gameofthrones/blob/master/character-predictions.csv
88
#
99
# Latest version is always available at: https://github.com/codelab-tf-got/code/
1010
# Codelab test is available at: https://codelab-tf-cot.github.io
11-
# Codelab code by @ssice . Front @SoyGema
11+
# Codelab code by @ssice . Front @SoyGema
1212
# ==============================================================================
1313

1414
"""Import Python 2-3 compatibility glue, ETL (pandas) and ML (TensorFlow/sklearn) libraries"""
@@ -26,7 +26,7 @@
2626
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
2727
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
2828

29-
from sklearn import cross_validation # to split the train/test cases
29+
from sklearn import model_selection # to split the train/test cases
3030

3131

3232
## Uncomment the logging lines to see logs in the console
@@ -164,7 +164,7 @@ def get_dataset(filename, local_path='../dataset'):
164164
preset_deep_columns = []
165165

166166
def get_deep_columns():
167-
"""Obtains the deep columns of the model.
167+
"""Obtains the deep columns of the model.
168168
169169
In our model, these are the binary columns (which are embedded with
170170
keys "0" and "1") and the categorical columns, which are embedded as
@@ -223,7 +223,7 @@ def get_wide_columns():
223223
##############################################################################
224224
def build_estimator(model_dir):
225225
"""General estimator builder function.
226-
226+
227227
The wide/deep part construction is below. This gathers both parts
228228
and joins the model into a single classifier.
229229
@@ -349,7 +349,7 @@ def _experiment_fn(output_dir):
349349
)
350350
return experiment
351351
return _experiment_fn
352-
352+
353353

354354
def fill_dataframe(df_base):
355355
"""
@@ -378,7 +378,7 @@ def train_and_eval(job_dir=None):
378378
df_base[LABEL_COLUMN] = (
379379
df_base[LABEL_COLUMN].apply(lambda x: x)).astype(int)
380380

381-
df_train, df_test = cross_validation.train_test_split(df_base, test_size=0.2, random_state=42)
381+
df_train, df_test = model_selection.train_test_split(df_base, test_size=0.2, random_state=42)
382382

383383
model_dir = tempfile.mkdtemp() if not FLAGS.model_dir else FLAGS.model_dir
384384
print("model directory = %s" % model_dir)

0 commit comments

Comments
 (0)