
Problem using a TextCNN network in TF2 #9

@Smile-L-up


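I am porting the project from TF1 to TF2. Without the TextCNN network, both training and prediction work fine. Once TextCNN is added, the loss and accuracy during training look good, but every prediction is wrong. The TF2 TextCNN implementation below is essentially a direct port of the TF1 code. I also tried implementations based on tf.keras.layers.Conv2D() and Conv1D, but neither worked. I first suspected hyperparameters such as the number of training epochs, but even with parameters identical to those in your project (dropout is applied), the trained model still misbehaves, so I would like to ask for your advice.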
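Since a Conv1D version is mentioned above, here is a minimal sketch of the same TextCNN head built entirely from Keras layers. It is not the poster's actual Conv1D attempt (that code isn't shown). SEQ_LEN = 60 and EMBED_DIM = 312 are assumptions read off the hard-coded constants in the `textcnn` function in this issue, and `build_textcnn_conv1d` is a hypothetical name. Because every weight lives in a Keras layer, it is tracked by the model, trained, and saved with it. Note that Keras defaults (Glorot-uniform kernel, zero bias) differ from the truncated-normal / 0.1 initialization in the original.

```python
import tensorflow as tf

SEQ_LEN = 60      # assumed sequence length (the 60 in the original max_pool ksize)
EMBED_DIM = 312   # assumed embedding width (the 312 in the original filter_shape)
NUM_FILTERS = 128
FILTER_SIZES = [2, 3, 4, 5, 6, 7]


def build_textcnn_conv1d(seq_len=SEQ_LEN, embed_dim=EMBED_DIM):
    """Hypothetical TextCNN head made of Keras layers, so all weights are tracked."""
    inputs = tf.keras.Input(shape=(seq_len, embed_dim))
    pooled = []
    for size in FILTER_SIZES:
        # A Conv1D window of `size` tokens spanning the full embedding width plays
        # the role of the [size, embed_dim, 1, NUM_FILTERS] conv2d kernel with VALID padding.
        h = tf.keras.layers.Conv1D(NUM_FILTERS, size, padding="valid", activation="relu")(inputs)
        # Max over all (seq_len - size + 1) positions, like the original max_pool window.
        pooled.append(tf.keras.layers.GlobalMaxPooling1D()(h))
    # Shape: (batch, NUM_FILTERS * len(FILTER_SIZES)) = (batch, 768)
    outputs = tf.keras.layers.Concatenate()(pooled)
    return tf.keras.Model(inputs, outputs)


model = build_textcnn_conv1d()
x = tf.random.normal([4, SEQ_LEN, EMBED_DIM])
print(model(x).shape)  # (4, 768)
```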

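```python
import tensorflow as tf

def textcnn(x):
    # x: (batch, 60, 312) token embeddings; 60 (sequence length) and 312 (embedding
    # width) are implied by the hard-coded pooling window and filter shape below.
    pooled_outputs = []

    filter_sizes = [2, 3, 4, 5, 6, 7]
    inputs_expand = tf.expand_dims(x, -1)  # add a channel axis -> (batch, 60, 312, 1)
    for filter_size in filter_sizes:
        filter_shape = [filter_size, 312, 1, 128]
        # NOTE: these tf.Variable objects are created anew, with fresh random values,
        # every time textcnn() is called.
        W = tf.Variable(tf.random.truncated_normal(filter_shape, stddev=0.1), dtype=tf.float32, name="W")
        b = tf.Variable(tf.constant(0.1, shape=[128]), dtype=tf.float32, name="b")
        conv = tf.nn.conv2d(
            inputs_expand,
            W,
            strides=[1, 1, 1, 1],
            padding="VALID",
            name="conv")  # -> (batch, 60 - filter_size + 1, 1, 128)
        # Apply nonlinearity
        h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
        # Max-pool over all time steps (window = full conv output length)
        pooled = tf.nn.max_pool(
            h,
            ksize=[1, 60 - filter_size + 1, 1, 1],
            strides=[1, 1, 1, 1],
            padding='VALID',
            name="pool")  # -> (batch, 1, 1, 128)
        pooled_outputs.append(pooled)
    # Combine all the pooled features
    num_filters_total = 128 * len(filter_sizes)
    h_pool = tf.concat(pooled_outputs, 3)
    h_pool_flat = tf.reshape(h_pool, [-1, num_filters_total])  # (batch, 768)

    return h_pool_flat
```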
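One possible cause, offered only as a guess since the calling code isn't shown: in TF2, `textcnn()` creates new `tf.Variable` objects every time it runs. If it is called again when building the prediction path, or if the variables are never attached to a Keras model (for example when the function is wrapped in a `Lambda` layer, which does not track variables created inside it), prediction may run with fresh random weights that were never trained or restored from the checkpoint. A minimal sketch of the same computation as a hypothetical `tf.keras.layers.Layer` subclass, `TextCNN`, creates the weights once in `build()` with the original initializers and replaces the hard-coded pooling window with a max over the time axis:

```python
import tensorflow as tf


class TextCNN(tf.keras.layers.Layer):
    """Hypothetical Layer version of textcnn(): weights are created once in build()."""

    def __init__(self, filter_sizes=(2, 3, 4, 5, 6, 7), num_filters=128, **kwargs):
        super().__init__(**kwargs)
        self.filter_sizes = list(filter_sizes)
        self.num_filters = num_filters

    def build(self, input_shape):
        embed_dim = int(input_shape[-1])  # 312 in the original code
        self.kernels, self.biases = [], []
        for size in self.filter_sizes:
            # Same initialization as the original: truncated normal (stddev 0.1), bias 0.1.
            self.kernels.append(self.add_weight(
                name=f"W_{size}", shape=[size, embed_dim, 1, self.num_filters],
                initializer=tf.keras.initializers.TruncatedNormal(stddev=0.1)))
            self.biases.append(self.add_weight(
                name=f"b_{size}", shape=[self.num_filters],
                initializer=tf.keras.initializers.Constant(0.1)))
        super().build(input_shape)

    def call(self, x):
        # x: (batch, seq_len, embed_dim) -> (batch, seq_len, embed_dim, 1)
        inputs_expand = tf.expand_dims(x, -1)
        pooled_outputs = []
        for W, b in zip(self.kernels, self.biases):
            conv = tf.nn.conv2d(inputs_expand, W, strides=[1, 1, 1, 1], padding="VALID")
            h = tf.nn.relu(tf.nn.bias_add(conv, b))  # (batch, seq_len - size + 1, 1, num_filters)
            # Max over time; equals the original max_pool when seq_len == 60.
            pooled_outputs.append(tf.reduce_max(h, axis=[1, 2]))  # (batch, num_filters)
        return tf.concat(pooled_outputs, axis=-1)  # (batch, num_filters * len(filter_sizes))


layer = TextCNN()
x = tf.random.normal([4, 60, 312])
y1 = layer(x)
y2 = layer(x)  # same weights reused, so the output is identical
print(y1.shape)  # (4, 768)
```

Calling the original `textcnn(x)` twice on the same input gives different outputs, because each call draws new weights; the layer above returns the same output on every call, and its weights are saved and restored together with the model.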
