13
13
# See the License for the specific language governing permissions and
14
14
# limitations under the License.
15
15
16
- """TextCNN model from "Convolutional Neural Networks for Sentence Classification".
17
- """
16
+ """TextCNN (see Convolutional Neural Networks for Sentence Classification)."""
18
17
19
18
from __future__ import absolute_import
20
19
from __future__ import division
27
26
28
27
import tensorflow as tf
29
28
29
+
30
30
@registry .register_model
31
31
class TextCNN (t2t_model .T2TModel ):
32
32
"""Text CNN."""
33
33
34
34
def body (self , features ):
35
35
"""TextCNN main model_fn.
36
+
36
37
Args:
37
38
features: Map of features to the model. Should contain the following:
38
39
"inputs": Text inputs.
@@ -54,16 +55,20 @@ def body(self, features):
54
55
for _ , filter_size in enumerate (hparams .filter_sizes ):
55
56
with tf .name_scope ("conv-maxpool-%s" % filter_size ):
56
57
filter_shape = [filter_size , vocab_size , 1 , hparams .num_filters ]
57
- filter_var = tf .Variable (tf .truncated_normal (filter_shape , stddev = 0.1 ), name = "W" )
58
- filter_bias = tf .Variable (tf .constant (0.1 , shape = [hparams .num_filters ]), name = "b" )
58
+ filter_var = tf .Variable (
59
+ tf .truncated_normal (filter_shape , stddev = 0.1 ), name = "W" )
60
+ filter_bias = tf .Variable (
61
+ tf .constant (0.1 , shape = [hparams .num_filters ]), name = "b" )
59
62
conv = tf .nn .conv2d (
60
63
inputs ,
61
64
filter_var ,
62
65
strides = [1 , 1 , 1 , 1 ],
63
66
padding = "VALID" ,
64
67
name = "conv" )
65
- conv_outputs = tf .nn .relu (tf .nn .bias_add (conv , filter_bias ), name = "relu" )
66
- pooled = tf .math .reduce_max (conv_outputs , axis = 1 , keepdims = True , name = "max" )
68
+ conv_outputs = tf .nn .relu (
69
+ tf .nn .bias_add (conv , filter_bias ), name = "relu" )
70
+ pooled = tf .math .reduce_max (
71
+ conv_outputs , axis = 1 , keepdims = True , name = "max" )
67
72
pooled_outputs .append (pooled )
68
73
69
74
num_filters_total = hparams .num_filters * len (hparams .filter_sizes )
@@ -76,6 +81,7 @@ def body(self, features):
76
81
77
82
return output
78
83
84
+
79
85
@registry .register_hparams
80
86
def text_cnn_base ():
81
87
"""Set of hyperparameters."""
0 commit comments