Commit

text cnn classification

pantheonzhang committed Jul 23, 2020
1 parent b032081 commit 11859d5
Showing 10 changed files with 127,942 additions and 0 deletions.
10,782 changes: 10,782 additions & 0 deletions TextCNN/data/car.txt
20,082 changes: 20,082 additions & 0 deletions TextCNN/data/entertainment.txt
16,838 changes: 16,838 additions & 0 deletions TextCNN/data/military.txt
20,011 changes: 20,011 additions & 0 deletions TextCNN/data/sports.txt
20,034 changes: 20,034 additions & 0 deletions TextCNN/data/technology.txt
59 changes: 59 additions & 0 deletions TextCNN/main.py
@@ -0,0 +1,59 @@
import argparse
import os
import random
import logging

from text_cnn import TextCNN
from utils import *
from tensorflow.keras.preprocessing import sequence
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.utils import to_categorical

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="train")
    parser.add_argument("--data_dir", type=str, default="./data", help="data file")
    parser.add_argument("--vocab_file", type=str, default="./vocab/vocab.txt", help="vocab_file")
    parser.add_argument("--vocab_size", type=int, default=40000, help="vocab_size")
    parser.add_argument("--max_features", type=int, default=40001, help="max_features")
    parser.add_argument("--max_len", type=int, default=100, help="max_len")
    parser.add_argument("--batch_size", type=int, default=64, help="batch_size")
    parser.add_argument("--embedding_size", type=int, default=50, help="embedding_size")
    parser.add_argument("--epochs", type=int, default=3, help="epochs")
    args = parser.parse_args()

    logger.info('Loading data and building the vocabulary...')
    if not os.path.exists(args.vocab_file):
        build_vocab(args.data_dir, args.vocab_file, args.vocab_size)

    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(args.vocab_file)

    logger.info('Loading data...')
    data, label = read_files(args.data_dir)
    data = list(zip(data, label))
    random.shuffle(data)

    train_data, test_data = train_test_split(data)

    # Encode tokens as vocabulary ids and labels as one-hot vectors.
    data_train = encode_sentences([content[0] for content in train_data], word_to_id)
    label_train = to_categorical(encode_cate([content[1] for content in train_data], cat_to_id))
    data_test = encode_sentences([content[0] for content in test_data], word_to_id)
    label_test = to_categorical(encode_cate([content[1] for content in test_data], cat_to_id))

    # Pad/truncate every sequence to the same length.
    data_train = sequence.pad_sequences(data_train, maxlen=args.max_len)
    data_test = sequence.pad_sequences(data_test, maxlen=args.max_len)

    model = TextCNN(args.max_len, args.max_features, args.embedding_size).get_model()
    model.compile('adam', 'categorical_crossentropy', metrics=['accuracy'])

    logger.info('Starting training...')
    cnn_callbacks = [
        ModelCheckpoint('./model.h5', verbose=1),
        EarlyStopping(monitor='val_accuracy', patience=2, mode='max')
    ]

    history = model.fit(data_train, label_train,
                        batch_size=args.batch_size,
                        epochs=args.epochs,
                        callbacks=cnn_callbacks,
                        validation_data=(data_test, label_test))
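For reference, the model.h5 checkpoint written by this script can be loaded back for prediction. The sketch below is illustrative only and not part of the commit: it assumes the vocabulary file exists at ./vocab/vocab.txt, reuses the helpers from utils.py, and uses a made-up, pre-tokenized input sentence.

from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import sequence
from utils import read_vocab, read_category, encode_sentences

# Illustrative inference sketch; paths and the sample text are placeholders.
model = load_model('./model.h5')
words, word_to_id = read_vocab('./vocab/vocab.txt')
categories, cat_to_id = read_category()

tokens = 'already segmented words of one news article'.split(' ')
x = sequence.pad_sequences(encode_sentences([tokens], word_to_id), maxlen=100)  # maxlen matches --max_len
probs = model.predict(x)[0]
print(categories[probs.argmax()], float(probs.max()))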
Binary file added TextCNN/model.h5
27 changes: 27 additions & 0 deletions TextCNN/text_cnn.py
@@ -0,0 +1,27 @@
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense, Conv1D, Embedding, Concatenate, GlobalMaxPooling1D


class TextCNN(object):
    def __init__(self, max_len, max_features, embedding_dims,
                 class_num=5,
                 activation='softmax'):
        self.max_len = max_len
        self.max_features = max_features
        self.embedding_dims = embedding_dims
        self.class_num = class_num
        self.activation = activation

    def get_model(self):
        inputs = Input((self.max_len,))
        embedding = Embedding(self.max_features, self.embedding_dims, input_length=self.max_len)(inputs)
        # Parallel convolutions with kernel sizes 3/4/5, each followed by global max pooling.
        convs = []
        for kernel_size in [3, 4, 5]:
            c = Conv1D(128, kernel_size, activation='relu')(embedding)
            c = GlobalMaxPooling1D()(c)
            convs.append(c)
        x = Concatenate()(convs)

        output = Dense(self.class_num, activation=self.activation)(x)
        model = Model(inputs=inputs, outputs=output)
        return model
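As a quick sanity check (illustrative, not part of the commit), the network can be built with the same defaults main.py passes in and inspected with summary(), which should show the Embedding layer feeding three parallel Conv1D + GlobalMaxPooling1D branches that are concatenated before the softmax Dense layer:

from text_cnn import TextCNN

# max_len=100, max_features=40001 and embedding_dims=50 mirror main.py's defaults.
model = TextCNN(max_len=100, max_features=40001, embedding_dims=50).get_model()
model.summary()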
109 changes: 109 additions & 0 deletions TextCNN/utils.py
@@ -0,0 +1,109 @@
import os
import numpy as np
from collections import Counter
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical


def read_file(file_name):
    # Each line is expected to be "label<TAB>space-tokenized content".
    contents = []
    labels = []
    with open(file_name, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            try:
                raw = line.strip().split("\t")
                content = raw[1].split(' ')
                if content:
                    contents.append(content)
                    labels.append(raw[0])
            except IndexError:
                # Skip malformed lines that lack a label/content pair.
                continue
    return contents, labels


def read_single_file(file_name):
    # The file name (without extension) serves as the label, e.g. sports.txt -> "sports".
    contents = []
    label = os.path.basename(file_name).split('.')[0]
    with open(file_name, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            content = line.strip().split(' ')
            if content:
                contents.append(content)
    return contents, label


def read_files(directory):
    contents = []
    labels = []
    files = [f for f in os.listdir(directory) if f.endswith(".txt")]
    for file_name in files:
        content, label = read_single_file(os.path.join(directory, file_name))
        contents.extend(content)
        labels.extend([label] * len(content))
    return contents, labels


def build_vocab(train_dir, vocab_file, vocab_size=5000):
    data_train, _ = read_files(train_dir)

    all_data = []
    for content in data_train:
        all_data.extend(content)

    # Keep the most frequent tokens and reserve index 0 for the padding symbol.
    counter = Counter(all_data)
    count_pairs = counter.most_common(vocab_size - 1)
    words, _ = list(zip(*count_pairs))
    words = ['<PAD>'] + list(words)
    os.makedirs(os.path.dirname(vocab_file) or '.', exist_ok=True)
    with open(vocab_file, mode='w', encoding='utf-8', errors='ignore') as f:
        f.write('\n'.join(words) + '\n')


def read_vocab(vocab_file):
    with open(vocab_file, mode='r', encoding='utf-8', errors='ignore') as fp:
        words = [_.strip() for _ in fp.readlines()]
    word_to_id = dict(zip(words, range(len(words))))
    return words, word_to_id


def read_category():
    categories = ['car', 'entertainment', 'military', 'sports', 'technology']
    cat_to_id = dict(zip(categories, range(len(categories))))
    return categories, cat_to_id


def encode_sentences(contents, words):
    return [encode_cate(x, words) for x in contents]


def encode_cate(content, words):
    # Map each item to its id; out-of-vocabulary tokens fall back to id 40000
    # (which is why main.py sets max_features to 40001).
    return [(words[x] if x in words else 40000) for x in content]


def process_file(file_name, word_to_id, cat_to_id, max_length=600):
    contents, labels = read_file(file_name)

    data_id, label_id = [], []
    for i in range(len(contents)):
        data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])
        label_id.append(cat_to_id[labels[i]])

    x_pad = pad_sequences(data_id, max_length)
    y_pad = to_categorical(label_id, num_classes=len(cat_to_id))

    return x_pad, y_pad


def batch_iter(x, y, batch_size=64):
    # Shuffle the data and yield (x, y) mini-batches.
    data_len = len(x)
    num_batch = int((data_len - 1) / batch_size) + 1

    indices = np.random.permutation(np.arange(data_len))
    x_shuffle = x[indices]
    y_shuffle = y[indices]

    for i in range(num_batch):
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]
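Taken together, these helpers form the preprocessing pipeline main.py relies on. A hedged end-to-end sketch follows; paths, sizes, and the training loop are illustrative, and batch_iter itself is not used by main.py.

from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
from utils import (build_vocab, read_vocab, read_category, read_files,
                   encode_sentences, encode_cate, batch_iter)

# Build the vocabulary once, then encode and pad the whole corpus.
build_vocab('./data', './vocab/vocab.txt', vocab_size=40000)
words, word_to_id = read_vocab('./vocab/vocab.txt')
categories, cat_to_id = read_category()

contents, labels = read_files('./data')
x = pad_sequences(encode_sentences(contents, word_to_id), maxlen=100)
y = to_categorical(encode_cate(labels, cat_to_id), num_classes=len(cat_to_id))

# batch_iter supports a manual training loop, e.g. with model.train_on_batch:
for x_batch, y_batch in batch_iter(x, y, batch_size=64):
    pass  # replace with model.train_on_batch(x_batch, y_batch)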