Skip to content

Commit 641cb25

Browse files
ilblackdragonVijay Vasudevan
authored and
Vijay Vasudevan
committed
Refactor skflow
* Switching trainer to use layers/optimizers * Updated documentation for optimizer argument and added support for instance of sub-class of optimizer if user wants to provide optimizer with custom parameters. * Added support for passing learning_rate Tensor * Make trainer work with layers/optimize_loss * Move config_addong.py to estimators and rename ConfigAddon to RunConfig * Added size=small to all tests * Moved tf_master and tf_random_seed into RunConfig * Removing need in Trainer class - replacing it with just a function and creating optimizer in the setup_training function. * Moved saver params into RunConfig * Moved creation of placeholders inside data feeder. Can be used for multi inputs/outputs. * switch to use platform.default._gfile for file operations * Adding a build rule for all skflow examples * correcting build rule and import in neural translation example * Replacing one_hot_matrix by tf.one_hot (speed + GPU) * Fixing missing test and left over refactoring in montior.py * Fixed iris custom config example. Also removed duplication of names in trainer. * rename iris_config_addon to iris_run_config * Fix keep_prob->dropout arg rename in dnn * Removing .orig files * Bump test_early_stopping test to medium
1 parent 4b4eafa commit 641cb25

29 files changed

+555
-495
lines changed

tensorflow/contrib/layers/python/layers/optimizers.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,16 @@ class should be sub-class of tf.Optimizer that implements
8787
loss = control_flow_ops.with_dependencies([loss_averages_op], loss)
8888

8989
# Learning rate variable, with possible decay.
90-
lr = vs.get_variable("learning_rate",
91-
[],
92-
trainable=False,
93-
initializer=init_ops.constant_initializer(learning_rate))
90+
if isinstance(learning_rate, ops.Tensor) and len(learning_rate.get_shape()) == 0:
91+
lr = learning_rate
92+
elif isinstance(learning_rate, float):
93+
lr = vs.get_variable("learning_rate",
94+
[],
95+
trainable=False,
96+
initializer=init_ops.constant_initializer(learning_rate))
97+
else:
98+
raise ValueError("Learning rate should be 0d Tensor or float. Got %s" %
99+
str(learning_rate))
94100
if learning_rate_decay_fn is not None:
95101
lr = learning_rate_decay_fn(lr, global_step)
96102

@@ -147,3 +153,4 @@ class should be sub-class of tf.Optimizer that implements
147153
train_tensor = control_flow_ops.with_dependencies([grad_updates], final_loss)
148154

149155
return train_tensor
156+

tensorflow/contrib/layers/python/layers/optimizers_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,4 @@ def testGradientClip(self):
8080

8181
if __name__ == "__main__":
8282
tf.test.main()
83+

tensorflow/contrib/skflow/BUILD

+28
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ py_library(
1818

1919
py_test(
2020
name = "test_base",
21+
size = "small",
2122
srcs = ["python/skflow/tests/test_base.py"],
2223
srcs_version = "PY2AND3",
2324
deps = [
@@ -29,6 +30,7 @@ py_test(
2930

3031
py_test(
3132
name = "test_custom_decay",
33+
size = "small",
3234
srcs = ["python/skflow/tests/test_custom_decay.py"],
3335
srcs_version = "PY2AND3",
3436
deps = [
@@ -40,6 +42,7 @@ py_test(
4042

4143
py_test(
4244
name = "test_data_feeder",
45+
size = "small",
4346
srcs = ["python/skflow/tests/test_data_feeder.py"],
4447
srcs_version = "PY2AND3",
4548
deps = [
@@ -49,8 +52,21 @@ py_test(
4952
],
5053
)
5154

55+
py_test(
56+
name = "test_early_stopping",
57+
size = "medium",
58+
srcs = ["python/skflow/tests/test_early_stopping.py"],
59+
srcs_version = "PY2AND3",
60+
deps = [
61+
":skflow",
62+
"//tensorflow:tensorflow_py",
63+
"//tensorflow/python:framework_test_lib",
64+
],
65+
)
66+
5267
py_test(
5368
name = "test_estimators",
69+
size = "small",
5470
srcs = ["python/skflow/tests/test_estimators.py"],
5571
srcs_version = "PY2AND3",
5672
deps = [
@@ -62,6 +78,7 @@ py_test(
6278

6379
py_test(
6480
name = "test_grid_search",
81+
size = "small",
6582
srcs = ["python/skflow/tests/test_grid_search.py"],
6683
srcs_version = "PY2AND3",
6784
deps = [
@@ -73,6 +90,7 @@ py_test(
7390

7491
py_test(
7592
name = "test_io",
93+
size = "small",
7694
srcs = ["python/skflow/tests/test_io.py"],
7795
srcs_version = "PY2AND3",
7896
deps = [
@@ -84,6 +102,7 @@ py_test(
84102

85103
py_test(
86104
name = "test_multioutput",
105+
size = "small",
87106
srcs = ["python/skflow/tests/test_multioutput.py"],
88107
srcs_version = "PY2AND3",
89108
deps = [
@@ -95,6 +114,7 @@ py_test(
95114

96115
py_test(
97116
name = "test_nonlinear",
117+
size = "medium",
98118
srcs = ["python/skflow/tests/test_nonlinear.py"],
99119
srcs_version = "PY2AND3",
100120
deps = [
@@ -106,6 +126,7 @@ py_test(
106126

107127
py_test(
108128
name = "test_regression",
129+
size = "small",
109130
srcs = ["python/skflow/tests/test_regression.py"],
110131
srcs_version = "PY2AND3",
111132
deps = [
@@ -117,6 +138,7 @@ py_test(
117138

118139
py_test(
119140
name = "test_saver",
141+
size = "small",
120142
srcs = ["python/skflow/tests/test_saver.py"],
121143
srcs_version = "PY2AND3",
122144
deps = [
@@ -128,6 +150,7 @@ py_test(
128150

129151
py_test(
130152
name = "test_ops",
153+
size = "small",
131154
srcs = ["python/skflow/ops/tests/test_ops.py"],
132155
srcs_version = "PY2AND3",
133156
deps = [
@@ -139,6 +162,7 @@ py_test(
139162

140163
py_test(
141164
name = "test_dropout_ops",
165+
size = "small",
142166
srcs = ["python/skflow/ops/tests/test_dropout_ops.py"],
143167
srcs_version = "PY2AND3",
144168
deps = [
@@ -150,6 +174,7 @@ py_test(
150174

151175
py_test(
152176
name = "test_seq2seq_ops",
177+
size = "small",
153178
srcs = ["python/skflow/ops/tests/test_seq2seq_ops.py"],
154179
srcs_version = "PY2AND3",
155180
deps = [
@@ -161,6 +186,7 @@ py_test(
161186

162187
py_test(
163188
name = "test_categorical",
189+
size = "small",
164190
srcs = ["python/skflow/preprocessing/tests/test_categorical.py"],
165191
srcs_version = "PY2AND3",
166192
deps = [
@@ -172,6 +198,7 @@ py_test(
172198

173199
py_test(
174200
name = "test_categorical_vocabulary",
201+
size = "small",
175202
srcs = ["python/skflow/preprocessing/tests/test_categorical_vocabulary.py"],
176203
srcs_version = "PY2AND3",
177204
deps = [
@@ -183,6 +210,7 @@ py_test(
183210

184211
py_test(
185212
name = "test_text",
213+
size = "small",
186214
srcs = ["python/skflow/preprocessing/tests/test_text.py"],
187215
srcs_version = "PY2AND3",
188216
deps = [

tensorflow/contrib/skflow/__init__.py.orig

-23
This file was deleted.

tensorflow/contrib/skflow/python/__init__.py.orig

-23
This file was deleted.

tensorflow/contrib/skflow/python/skflow/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,3 @@
3636
from tensorflow.contrib.skflow.python.skflow import ops
3737
from tensorflow.contrib.skflow.python.skflow import preprocessing
3838
from tensorflow.contrib.skflow.python.skflow import models
39-
from tensorflow.contrib.skflow.python.skflow.trainer import TensorFlowTrainer

tensorflow/contrib/skflow/python/skflow/addons/__init__.py

-19
This file was deleted.

tensorflow/contrib/skflow/python/skflow/addons/config_addon.py

-40
This file was deleted.

tensorflow/contrib/skflow/python/skflow/estimators/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@
2525
from tensorflow.contrib.skflow.python.skflow.estimators.dnn import TensorFlowDNNRegressor
2626
from tensorflow.contrib.skflow.python.skflow.estimators.rnn import TensorFlowRNNClassifier
2727
from tensorflow.contrib.skflow.python.skflow.estimators.rnn import TensorFlowRNNRegressor
28+
from tensorflow.contrib.skflow.python.skflow.estimators.run_config import RunConfig

0 commit comments

Comments
 (0)