Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

add text CNN model for text classification problem #1271

Merged
merged 1 commit into from
Dec 5, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions tensor2tensor/models/text_cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TextCNN model from "Convolutional Neural Networks for Sentence Classification".
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow as tf

@registry.register_model
class TextCNN(t2t_model.T2TModel):
"""Text CNN."""

def body(self, features):
"""TextCNN main model_fn.
Args:
features: Map of features to the model. Should contain the following:
"inputs": Text inputs.
[batch_size, input_length, 1, hidden_dim].
"targets": Target encoder outputs.
[batch_size, 1, 1, hidden_dim]
Returns:
Final encoder representation. [batch_size, 1, 1, hidden_dim]
"""
hparams = self._hparams
inputs = features["inputs"]

xshape = common_layers.shape_list(inputs)

vocab_size = xshape[3]
inputs = tf.reshape(inputs, [xshape[0], xshape[1], xshape[3], xshape[2]])

pooled_outputs = []
for _, filter_size in enumerate(hparams.filter_sizes):
with tf.name_scope("conv-maxpool-%s" % filter_size):
filter_shape = [filter_size, vocab_size, 1, hparams.num_filters]
filter_var = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W")
filter_bias = tf.Variable(tf.constant(0.1, shape=[hparams.num_filters]), name="b")
conv = tf.nn.conv2d(
inputs,
filter_var,
strides=[1, 1, 1, 1],
padding="VALID",
name="conv")
conv_outputs = tf.nn.relu(tf.nn.bias_add(conv, filter_bias), name="relu")
pooled = tf.math.reduce_max(conv_outputs, axis=1, keepdims=True, name="max")
pooled_outputs.append(pooled)

num_filters_total = hparams.num_filters * len(hparams.filter_sizes)
h_pool = tf.concat(pooled_outputs, 3)
h_pool_flat = tf.reshape(h_pool, [-1, num_filters_total])

# Add dropout
output = tf.nn.dropout(h_pool_flat, 1 - hparams.output_dropout)
output = tf.reshape(output, [-1, 1, 1, num_filters_total])

return output

@registry.register_hparams
def text_cnn_base():
"""Set of hyperparameters."""
hparams = common_hparams.basic_params1()
hparams.batch_size = 4096
hparams.max_length = 256
hparams.clip_grad_norm = 0. # i.e. no gradient clipping
hparams.optimizer_adam_epsilon = 1e-9
hparams.learning_rate_schedule = "legacy"
hparams.learning_rate_decay_scheme = "noam"
hparams.learning_rate = 0.1
hparams.learning_rate_warmup_steps = 4000
hparams.initializer_gain = 1.0
hparams.num_hidden_layers = 6
hparams.initializer = "uniform_unit_scaling"
hparams.weight_decay = 0.0
hparams.optimizer_adam_beta1 = 0.9
hparams.optimizer_adam_beta2 = 0.98
hparams.num_sampled_classes = 0
hparams.label_smoothing = 0.1
hparams.shared_embedding_and_softmax_weights = True
hparams.symbol_modality_num_shards = 16

# Add new ones like this.
hparams.add_hparam("filter_sizes", [2, 3, 4, 5])
hparams.add_hparam("num_filters", 128)
hparams.add_hparam("output_dropout", 0.4)
return hparams