Skip to content

Commit 932b640

Browse files
yuwen-yankpe
authored andcommitted
add text CNN model for text classification problem (tensorflow#1271)
1 parent 40056c4 commit 932b640

File tree

1 file changed

+106
-0
lines changed

1 file changed

+106
-0
lines changed

tensor2tensor/models/text_cnn.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# coding=utf-8
2+
# Copyright 2018 The Tensor2Tensor Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""TextCNN model from "Convolutional Neural Networks for Sentence Classification".
17+
"""
18+
19+
from __future__ import absolute_import
20+
from __future__ import division
21+
from __future__ import print_function
22+
23+
from tensor2tensor.layers import common_hparams
24+
from tensor2tensor.layers import common_layers
25+
from tensor2tensor.utils import registry
26+
from tensor2tensor.utils import t2t_model
27+
28+
import tensorflow as tf
29+
30+
@registry.register_model
31+
class TextCNN(t2t_model.T2TModel):
32+
"""Text CNN."""
33+
34+
def body(self, features):
35+
"""TextCNN main model_fn.
36+
Args:
37+
features: Map of features to the model. Should contain the following:
38+
"inputs": Text inputs.
39+
[batch_size, input_length, 1, hidden_dim].
40+
"targets": Target encoder outputs.
41+
[batch_size, 1, 1, hidden_dim]
42+
Returns:
43+
Final encoder representation. [batch_size, 1, 1, hidden_dim]
44+
"""
45+
hparams = self._hparams
46+
inputs = features["inputs"]
47+
48+
xshape = common_layers.shape_list(inputs)
49+
50+
vocab_size = xshape[3]
51+
inputs = tf.reshape(inputs, [xshape[0], xshape[1], xshape[3], xshape[2]])
52+
53+
pooled_outputs = []
54+
for _, filter_size in enumerate(hparams.filter_sizes):
55+
with tf.name_scope("conv-maxpool-%s" % filter_size):
56+
filter_shape = [filter_size, vocab_size, 1, hparams.num_filters]
57+
filter_var = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W")
58+
filter_bias = tf.Variable(tf.constant(0.1, shape=[hparams.num_filters]), name="b")
59+
conv = tf.nn.conv2d(
60+
inputs,
61+
filter_var,
62+
strides=[1, 1, 1, 1],
63+
padding="VALID",
64+
name="conv")
65+
conv_outputs = tf.nn.relu(tf.nn.bias_add(conv, filter_bias), name="relu")
66+
pooled = tf.math.reduce_max(conv_outputs, axis=1, keepdims=True, name="max")
67+
pooled_outputs.append(pooled)
68+
69+
num_filters_total = hparams.num_filters * len(hparams.filter_sizes)
70+
h_pool = tf.concat(pooled_outputs, 3)
71+
h_pool_flat = tf.reshape(h_pool, [-1, num_filters_total])
72+
73+
# Add dropout
74+
output = tf.nn.dropout(h_pool_flat, 1 - hparams.output_dropout)
75+
output = tf.reshape(output, [-1, 1, 1, num_filters_total])
76+
77+
return output
78+
79+
@registry.register_hparams
80+
def text_cnn_base():
81+
"""Set of hyperparameters."""
82+
hparams = common_hparams.basic_params1()
83+
hparams.batch_size = 4096
84+
hparams.max_length = 256
85+
hparams.clip_grad_norm = 0. # i.e. no gradient clipping
86+
hparams.optimizer_adam_epsilon = 1e-9
87+
hparams.learning_rate_schedule = "legacy"
88+
hparams.learning_rate_decay_scheme = "noam"
89+
hparams.learning_rate = 0.1
90+
hparams.learning_rate_warmup_steps = 4000
91+
hparams.initializer_gain = 1.0
92+
hparams.num_hidden_layers = 6
93+
hparams.initializer = "uniform_unit_scaling"
94+
hparams.weight_decay = 0.0
95+
hparams.optimizer_adam_beta1 = 0.9
96+
hparams.optimizer_adam_beta2 = 0.98
97+
hparams.num_sampled_classes = 0
98+
hparams.label_smoothing = 0.1
99+
hparams.shared_embedding_and_softmax_weights = True
100+
hparams.symbol_modality_num_shards = 16
101+
102+
# Add new ones like this.
103+
hparams.add_hparam("filter_sizes", [2, 3, 4, 5])
104+
hparams.add_hparam("num_filters", 128)
105+
hparams.add_hparam("output_dropout", 0.4)
106+
return hparams

0 commit comments

Comments
 (0)