Skip to content

Commit 7d93e39

Browse files
yuwen-yankpe
authored andcommitted
internal merge of PR tensorflow#1271
PiperOrigin-RevId: 224227181
1 parent 932b640 commit 7d93e39

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

tensor2tensor/models/text_cnn.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""TextCNN model from "Convolutional Neural Networks for Sentence Classification".
17-
"""
16+
"""TextCNN (see Convolutional Neural Networks for Sentence Classification)."""
1817

1918
from __future__ import absolute_import
2019
from __future__ import division
@@ -27,12 +26,14 @@
2726

2827
import tensorflow as tf
2928

29+
3030
@registry.register_model
3131
class TextCNN(t2t_model.T2TModel):
3232
"""Text CNN."""
3333

3434
def body(self, features):
3535
"""TextCNN main model_fn.
36+
3637
Args:
3738
features: Map of features to the model. Should contain the following:
3839
"inputs": Text inputs.
@@ -54,16 +55,20 @@ def body(self, features):
5455
for _, filter_size in enumerate(hparams.filter_sizes):
5556
with tf.name_scope("conv-maxpool-%s" % filter_size):
5657
filter_shape = [filter_size, vocab_size, 1, hparams.num_filters]
57-
filter_var = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W")
58-
filter_bias = tf.Variable(tf.constant(0.1, shape=[hparams.num_filters]), name="b")
58+
filter_var = tf.Variable(
59+
tf.truncated_normal(filter_shape, stddev=0.1), name="W")
60+
filter_bias = tf.Variable(
61+
tf.constant(0.1, shape=[hparams.num_filters]), name="b")
5962
conv = tf.nn.conv2d(
6063
inputs,
6164
filter_var,
6265
strides=[1, 1, 1, 1],
6366
padding="VALID",
6467
name="conv")
65-
conv_outputs = tf.nn.relu(tf.nn.bias_add(conv, filter_bias), name="relu")
66-
pooled = tf.math.reduce_max(conv_outputs, axis=1, keepdims=True, name="max")
68+
conv_outputs = tf.nn.relu(
69+
tf.nn.bias_add(conv, filter_bias), name="relu")
70+
pooled = tf.math.reduce_max(
71+
conv_outputs, axis=1, keepdims=True, name="max")
6772
pooled_outputs.append(pooled)
6873

6974
num_filters_total = hparams.num_filters * len(hparams.filter_sizes)
@@ -76,6 +81,7 @@ def body(self, features):
7681

7782
return output
7883

84+
7985
@registry.register_hparams
8086
def text_cnn_base():
8187
"""Set of hyperparameters."""

0 commit comments

Comments
 (0)