Skip to content

Commit 0d350f9

Browse files
terrytangyuanilblackdragon
authored andcommitted
[tf.learn] High-level DNNAutoencoder (tensorflow#2088)
* Added BaseTransformer and DNNAutoencoder * Fix conflict and added example
1 parent 5fa51de commit 0d350f9

File tree

10 files changed

+276
-4
lines changed

10 files changed

+276
-4
lines changed

tensorflow/contrib/learn/python/learn/estimators/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from __future__ import division
1717
from __future__ import print_function
1818

19-
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator
19+
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator, TensorFlowBaseTransformer
2020
from tensorflow.contrib.learn.python.learn.estimators.linear import TensorFlowLinearClassifier
2121
from tensorflow.contrib.learn.python.learn.estimators.linear import TensorFlowClassifier
2222
from tensorflow.contrib.learn.python.learn.estimators.linear import TensorFlowLinearRegressor
@@ -25,4 +25,5 @@
2525
from tensorflow.contrib.learn.python.learn.estimators.dnn import TensorFlowDNNRegressor
2626
from tensorflow.contrib.learn.python.learn.estimators.rnn import TensorFlowRNNClassifier
2727
from tensorflow.contrib.learn.python.learn.estimators.rnn import TensorFlowRNNRegressor
28+
from tensorflow.contrib.learn.python.learn.estimators.autoencoder import TensorFlowDNNAutoencoder
2829
from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig

tensorflow/contrib/learn/python/learn/estimators/_sklearn.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ class _RegressorMixin():
111111
"""Mixin class for all regression estimators."""
112112
pass
113113

114+
class _TransformerMixin():
115+
"""Mixin class for all transformer estimators."""
114116

115117
class _NotFittedError(ValueError, AttributeError):
116118
"""Exception class to raise if estimator is used before fitting.
@@ -167,10 +169,11 @@ def _train_test_split(*args, **options):
167169
result += [x.take(train_idx, axis=0), x.take(test_idx, axis=0)]
168170
return tuple(result)
169171

172+
170173
# If "TENSORFLOW_SKLEARN" flag is defined then try to import from sklearn.
171174
TRY_IMPORT_SKLEARN = os.environ.get('TENSORFLOW_SKLEARN', False)
172175
if TRY_IMPORT_SKLEARN:
173-
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
176+
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, TransformerMixin
174177
from sklearn.metrics import accuracy_score, log_loss, mean_squared_error
175178
from sklearn.cross_validation import train_test_split
176179
try:
@@ -185,6 +188,7 @@ def _train_test_split(*args, **options):
185188
BaseEstimator = _BaseEstimator
186189
ClassifierMixin = _ClassifierMixin
187190
RegressorMixin = _RegressorMixin
191+
TransformerMixin = _TransformerMixin
188192
NotFittedError = _NotFittedError
189193
accuracy_score = _accuracy_score
190194
log_loss = None
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""Deep Autoencoder estimators."""
2+
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
from __future__ import print_function
18+
19+
from tensorflow.python.ops import nn
20+
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
21+
from tensorflow.contrib.learn.python.learn import models
22+
23+
24+
class TensorFlowDNNAutoencoder(TensorFlowBaseTransformer):
25+
"""TensorFlow Autoencoder Regressor model.
26+
27+
Parameters:
28+
hidden_units: List of hidden units per layer.
29+
batch_size: Mini batch size.
30+
activation: activation function used to map inner latent layer onto
31+
reconstruction layer.
32+
add_noise: a function that adds noise to tensor_in,
33+
e.g. def add_noise(x):
34+
return(x + np.random.normal(0, 0.1, (len(x), len(x[0]))))
35+
steps: Number of steps to run over data.
36+
optimizer: Optimizer name (or class), for example "SGD", "Adam",
37+
"Adagrad".
38+
learning_rate: If this is constant float value, no decay function is used.
39+
Instead, a customized decay function can be passed that accepts
40+
global_step as parameter and returns a Tensor.
41+
e.g. exponential decay function:
42+
def exp_decay(global_step):
43+
return tf.train.exponential_decay(
44+
learning_rate=0.1, global_step,
45+
decay_steps=2, decay_rate=0.001)
46+
continue_training: when continue_training is True, once initialized
47+
model will be continuely trained on every call of fit.
48+
config: RunConfig object that controls the configurations of the session,
49+
e.g. num_cores, gpu_memory_fraction, etc.
50+
verbose: Controls the verbosity, possible values:
51+
0: the algorithm and debug information is muted.
52+
1: trainer prints the progress.
53+
2: log device placement is printed.
54+
dropout: When not None, the probability we will drop out a given
55+
coordinate.
56+
"""
57+
def __init__(self, hidden_units, n_classes=0, batch_size=32,
58+
steps=200, optimizer="Adagrad", learning_rate=0.1,
59+
clip_gradients=5.0, activation=nn.relu, add_noise=None,
60+
continue_training=False, config=None,
61+
verbose=1, dropout=None):
62+
self.hidden_units = hidden_units
63+
self.dropout = dropout
64+
self.activation = activation
65+
self.add_noise = add_noise
66+
super(TensorFlowDNNAutoencoder, self).__init__(
67+
model_fn=self._model_fn,
68+
n_classes=n_classes,
69+
batch_size=batch_size, steps=steps, optimizer=optimizer,
70+
learning_rate=learning_rate, clip_gradients=clip_gradients,
71+
continue_training=continue_training,
72+
config=config, verbose=verbose)
73+
74+
def _model_fn(self, X, y):
75+
encoder, decoder, autoencoder_estimator = models.get_autoencoder_model(
76+
self.hidden_units,
77+
models.linear_regression,
78+
activation=self.activation,
79+
add_noise=self.add_noise,
80+
dropout=self.dropout)(X)
81+
self.encoder = encoder
82+
self.decoder = decoder
83+
return autoencoder_estimator
84+
85+
def generate(self, hidden=None):
86+
"""Generate new data using trained construction layer"""
87+
if hidden is None:
88+
last_layer = len(self.hidden_units) - 1
89+
bias = self.get_tensor_value('encoder/dnn/layer%d/Linear/Bias:0' % last_layer)
90+
import numpy as np
91+
hidden = np.random.normal(size=bias.shape)
92+
hidden = np.reshape(hidden, (1, len(hidden)))
93+
return self._session.run(self.decoder, feed_dict={self.encoder: hidden})
94+
95+
@property
96+
def weights_(self):
97+
"""Returns weights of the autoencoder's weight layers."""
98+
weights = []
99+
for layer in range(len(self.hidden_units)):
100+
weights.append(self.get_tensor_value('encoder/dnn/layer%d/Linear/Matrix:0' % layer))
101+
for layer in range(len(self.hidden_units)):
102+
weights.append(self.get_tensor_value('decoder/dnn/layer%d/Linear/Matrix:0' % layer))
103+
weights.append(self.get_tensor_value('linear_regression/weights:0'))
104+
return weights
105+
106+
@property
107+
def bias_(self):
108+
"""Returns bias of the autoencoder's bias layers."""
109+
biases = []
110+
for layer in range(len(self.hidden_units)):
111+
biases.append(self.get_tensor_value('encoder/dnn/layer%d/Linear/Bias:0' % layer))
112+
for layer in range(len(self.hidden_units)):
113+
biases.append(self.get_tensor_value('decoder/dnn/layer%d/Linear/Bias:0' % layer))
114+
biases.append(self.get_tensor_value('linear_regression/bias:0'))
115+
return biases
116+

tensorflow/contrib/learn/python/learn/estimators/base.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ def _setup_training(self):
148148
# Add histograms for X and y if they are floats.
149149
if self._data_feeder.input_dtype in (np.float32, np.float64):
150150
logging_ops.histogram_summary('X', self._inp)
151-
if self._data_feeder.output_dtype in (np.float32, np.float64):
151+
if self._data_feeder.output_dtype in (np.float32, np.float64)\
152+
and self._out is not None:
152153
logging_ops.histogram_summary('y', self._out)
153154

154155
# Create model's graph.
@@ -408,7 +409,8 @@ def _setup_training(self):
408409
# Add histograms for X and y if they are floats.
409410
if self._data_feeder.input_dtype in (np.float32, np.float64):
410411
logging_ops.histogram_summary("X", self._inp)
411-
if self._data_feeder.output_dtype in (np.float32, np.float64):
412+
if self._data_feeder.output_dtype in (np.float32, np.float64)\
413+
and self._out is not None:
412414
logging_ops.histogram_summary("y", self._out)
413415

414416
# Create model's graph.
@@ -959,3 +961,18 @@ def restore(cls, path, config=None):
959961
estimator = getattr(estimators, class_name)(**model_def)
960962
estimator._restore(path)
961963
return estimator
964+
965+
966+
class TensorFlowBaseTransformer(TensorFlowEstimator, _sklearn.TransformerMixin):
967+
"""TensorFlow Base Transformer class."""
968+
def transform(self, X):
969+
"""Transform X using trained transformer."""
970+
return(super(TensorFlowBaseTransformer, self).predict(X, axis=1, batch_size=None))
971+
972+
def fit(self, X, y=None, monitor=None, logdir=None):
973+
"""Fit a transformer."""
974+
return(super(TensorFlowBaseTransformer, self).fit(X, y, monitor=None, logdir=None))
975+
976+
def fit_transform(self, X, y=None, monitor=None, logdir=None):
977+
"""Fit transformer and transform X using trained transformer."""
978+
return(self.fit(X, y, monitor=None, logdir=None).transform(X))

tensorflow/contrib/learn/python/learn/models.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from tensorflow.contrib.learn.python.learn.ops import dnn_ops
2020
from tensorflow.contrib.learn.python.learn.ops import losses_ops
21+
from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops
2122
from tensorflow.python.framework import dtypes
2223
from tensorflow.python.framework import ops
2324
from tensorflow.python.ops import array_ops as array_ops_
@@ -187,6 +188,36 @@ def dnn_estimator(X, y):
187188

188189
return dnn_estimator
189190

191+
def get_autoencoder_model(hidden_units, target_predictor_fn,
192+
activation, add_noise=None, dropout=None):
193+
"""Returns a function that creates a Autoencoder TensorFlow subgraph with given
194+
params.
195+
196+
Args:
197+
hidden_units: List of values of hidden units for layers.
198+
target_predictor_fn: Function that will predict target from input
199+
features. This can be logistic regression,
200+
linear regression or any other model,
201+
that takes X, y and returns predictions and loss tensors.
202+
activation: activation function used to map inner latent layer onto
203+
reconstruction layer.
204+
add_noise: a function that adds noise to tensor_in,
205+
e.g. def add_noise(x):
206+
return(x + np.random.normal(0, 0.1, (len(x), len(x[0]))))
207+
dropout: When not none, causes dropout regularization to be used,
208+
with the specified probability of removing a given coordinate.
209+
210+
Returns:
211+
A function that creates the subgraph.
212+
"""
213+
def dnn_autoencoder_estimator(X):
214+
"""Autoencoder estimator with target predictor function on top."""
215+
encoder, decoder = autoencoder_ops.dnn_autoencoder(
216+
X, hidden_units, activation,
217+
add_noise=add_noise, dropout=dropout)
218+
return encoder, decoder, target_predictor_fn(X, decoder)
219+
return dnn_autoencoder_estimator
220+
190221
## This will be in Tensorflow 0.7.
191222
## TODO(ilblackdragon): Clean this up when it's released
192223

tensorflow/contrib/learn/python/learn/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from tensorflow.contrib.learn.python.learn.ops.array_ops import *
2020
from tensorflow.contrib.learn.python.learn.ops.conv_ops import *
2121
from tensorflow.contrib.learn.python.learn.ops.dnn_ops import *
22+
from tensorflow.contrib.learn.python.learn.ops.autoencoder_ops import *
2223
from tensorflow.contrib.learn.python.learn.ops.dropout_ops import *
2324
from tensorflow.contrib.learn.python.learn.ops.embeddings_ops import *
2425
from tensorflow.contrib.learn.python.learn.ops.losses_ops import *
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""TensorFlow ops for autoencoder."""
2+
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
from __future__ import print_function
18+
19+
from tensorflow.python.ops import nn
20+
from tensorflow.python.ops import variable_scope as vs
21+
from tensorflow.contrib.learn.python.learn.ops import dnn_ops
22+
23+
24+
def dnn_autoencoder(tensor_in, hidden_units,
25+
activation=nn.relu, add_noise=None,
26+
dropout=None, scope=None):
27+
"""Creates fully connected autoencoder subgraph.
28+
29+
Args:
30+
tensor_in: tensor or placeholder for input features.
31+
hidden_units: list of counts of hidden units in each layer.
32+
activation: activation function used to map inner latent layer onto
33+
reconstruction layer.
34+
add_noise: a function that adds noise to tensor_in,
35+
e.g. def add_noise(x):
36+
return(x + np.random.normal(0, 0.1, (len(x), len(x[0]))))
37+
dropout: if not None, will add a dropout layer with given
38+
probability.
39+
scope: the variable scope for this op.
40+
41+
Returns:
42+
Tensors for encoder and decoder.
43+
"""
44+
with vs.variable_op_scope([tensor_in], scope, "autoencoder"):
45+
if add_noise is not None:
46+
tensor_in = add_noise(tensor_in)
47+
with vs.variable_scope('encoder'):
48+
# build DNN encoder
49+
encoder = dnn_ops.dnn(tensor_in, hidden_units,
50+
activation=activation, dropout=dropout)
51+
with vs.variable_scope('decoder'):
52+
# reverse hidden_units and built DNN decoder
53+
decoder = dnn_ops.dnn(encoder, hidden_units[::-1],
54+
activation=activation, dropout=dropout)
55+
return encoder, decoder
56+

tensorflow/contrib/learn/python/learn/tests/test_nonlinear.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,16 @@ def input_fn(X):
165165
5, 6]])))
166166
self.assertAllClose(predictions, np.array([1, 0]))
167167

168+
def testDNNAutoencoder(self):
169+
import numpy as np
170+
iris = datasets.load_iris()
171+
autoencoder = learn.TensorFlowDNNAutoencoder(hidden_units=[10, 20])
172+
transformed = autoencoder.fit_transform(iris.data[1:2])
173+
expected = np.array([[ -3.57627869e-07, 1.17000043e+00, 1.01902664e+00, 1.19209290e-07,
174+
0.00000000e+00, 1.19209290e-07, -5.96046448e-08, -2.38418579e-07,
175+
9.74681854e-01, 1.19209290e-07]])
176+
self.assertAllClose(transformed, expected)
177+
168178

169179
if __name__ == "__main__":
170180
tf.test.main()

tensorflow/examples/skflow/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Some examples use the `pandas` library for data processing (`sudo pip install pa
1010
* [Deep Neural Network Regression with Boston Data](boston.py)
1111
* [Convolutional Neural Networks with Digits Data](digits.py)
1212
* [Deep Neural Network Classification with Iris Data](iris.py)
13+
* [Deep Neural Network Autoencoder with Iris Data](dnn_autoencoder_iris.py)
1314
* [Grid search and Deep Neural Network Classification](iris_gridsearch_cv.py)
1415
* [Deep Neural Network with Customized Decay Function](iris_custom_decay_dnn.py)
1516
* [Building A Custom Model](iris_custom_model.py)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import absolute_import
15+
from __future__ import division
16+
from __future__ import print_function
17+
18+
import random
19+
20+
import tensorflow as tf
21+
from tensorflow.contrib.learn.python import learn
22+
from tensorflow.contrib.learn.python.learn import datasets
23+
24+
# Load Iris Data
25+
iris = datasets.load_iris()
26+
27+
# Initialize a deep neural network autoencoder
28+
# You can also add noise and add dropout if needed
29+
# Details see TensorFlowDNNAutoencoder documentation.
30+
autoencoder = learn.TensorFlowDNNAutoencoder(hidden_units=[10, 20])
31+
32+
# Fit with Iris data
33+
transformed = autoencoder.fit_transform(iris.data)
34+
35+
print(transformed)

0 commit comments

Comments
 (0)