forked from isperfee/TextClassification
Commit
commit 11859d5 · 1 parent b032081 · 10 changed files with 127,942 additions and 0 deletions.
Five of the changed files are large diffs that are not rendered by default.
@@ -0,0 +1,59 @@
import argparse
import os
import random
import logging

from text_cnn import TextCNN
from utils import *
from tensorflow.keras.preprocessing import sequence
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.utils import to_categorical

# Configure logging so the logger.info calls below are actually emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="train")
    parser.add_argument("--data_dir", type=str, default="./data", help="data directory")
    parser.add_argument("--vocab_file", type=str, default="./vocab/vocab.txt", help="vocab file")
    # The default must be the int 40000, not the string "40000".
    parser.add_argument("--vocab_size", type=int, default=40000, help="vocab size")
    # 40000 vocabulary entries plus one out-of-vocabulary bucket (see encode_cate in utils).
    parser.add_argument("--max_features", type=int, default=40001, help="max features")
    parser.add_argument("--max_len", type=int, default=100, help="max sequence length")
    parser.add_argument("--batch_size", type=int, default=64, help="batch size")
    parser.add_argument("--embedding_size", type=int, default=50, help="embedding size")
    parser.add_argument("--epochs", type=int, default=3, help="epochs")
    args = parser.parse_args()

    logger.info('Loading data and building the vocabulary...')
    if not os.path.exists(args.vocab_file):
        build_vocab(args.data_dir, args.vocab_file, args.vocab_size)

    categories, cat_to_id = read_category()
    words, word_to_id = read_vocab(args.vocab_file)

    logger.info('Loading data...')
    data, label = read_files(args.data_dir)
    data = list(zip(data, label))
    random.shuffle(data)

    train_data, test_data = train_test_split(data)

    # Map tokens to vocabulary ids and labels to one-hot vectors.
    data_train = encode_sentences([content[0] for content in train_data], word_to_id)
    label_train = to_categorical(encode_cate([content[1] for content in train_data], cat_to_id))
    data_test = encode_sentences([content[0] for content in test_data], word_to_id)
    label_test = to_categorical(encode_cate([content[1] for content in test_data], cat_to_id))

    data_train = sequence.pad_sequences(data_train, maxlen=args.max_len)
    data_test = sequence.pad_sequences(data_test, maxlen=args.max_len)

    model = TextCNN(args.max_len, args.max_features, args.embedding_size).get_model()
    model.compile('adam', 'categorical_crossentropy', metrics=['accuracy'])

    logger.info('Starting training...')
    cnn_callbacks = [
        ModelCheckpoint('./model.h5', verbose=1),
        EarlyStopping(monitor='val_accuracy', patience=2, mode='max')
    ]

    history = model.fit(data_train, label_train,
                        batch_size=args.batch_size,
                        epochs=args.epochs,
                        callbacks=cnn_callbacks,
                        validation_data=(data_test, label_test))
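
After training, the ModelCheckpoint file can be reloaded for inference. A minimal sketch, assuming the checkpoint at ./model.h5 and vocabulary at ./vocab/vocab.txt from the defaults above; the token list is a placeholder, since real inputs are pre-segmented the same way as the training data:

from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import sequence
from utils import read_vocab, read_category, encode_sentences

model = load_model('./model.h5')
words, word_to_id = read_vocab('./vocab/vocab.txt')
categories, _ = read_category()

tokens = ['the', 'team', 'won', 'the', 'match']  # placeholder token list
# maxlen must match the --max_len used at training time.
x = sequence.pad_sequences(encode_sentences([tokens], word_to_id), maxlen=100)
probs = model.predict(x)[0]
print(categories[probs.argmax()])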
Binary file not shown.
@@ -0,0 +1,27 @@
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense, Conv1D, Embedding, Concatenate, GlobalMaxPooling1D


class TextCNN(object):
    def __init__(self, max_len, max_features, embedding_dims,
                 class_num=5,
                 activation='softmax'):
        self.max_len = max_len
        self.max_features = max_features
        self.embedding_dims = embedding_dims
        self.class_num = class_num
        self.activation = activation

    def get_model(self):
        # `inputs` rather than `input`, to avoid shadowing the built-in.
        inputs = Input((self.max_len,))
        embedding = Embedding(self.max_features, self.embedding_dims, input_length=self.max_len)(inputs)
        # Three parallel convolution branches (kernel sizes 3, 4, 5),
        # each max-pooled over the time dimension, then concatenated.
        convs = []
        for kernel_size in [3, 4, 5]:
            c = Conv1D(128, kernel_size, activation='relu')(embedding)
            c = GlobalMaxPooling1D()(c)
            convs.append(c)
        x = Concatenate()(convs)

        output = Dense(self.class_num, activation=self.activation)(x)
        model = Model(inputs=inputs, outputs=output)
        return model
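
Each branch produces 128 features, so the concatenated vector is 384-dimensional before the softmax layer. A quick way to verify the architecture (parameter values here mirror the training defaults):

from text_cnn import TextCNN

model = TextCNN(max_len=100, max_features=40001, embedding_dims=50).get_model()
model.summary()  # expect a 384-dim Concatenate feeding Dense(5, softmax)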
@@ -0,0 +1,109 @@
import os
import numpy as np
from collections import Counter
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical


def read_file(file_name):
    """Read a file of `label<TAB>space-joined tokens` lines."""
    contents = []
    labels = []
    with open(file_name, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            try:
                raw = line.strip().split("\t")
                content = raw[1].split(' ')
                if content:
                    contents.append(content)
                    labels.append(raw[0])
            except IndexError:
                # Skip malformed lines that have no tab separator.
                pass
    return contents, labels


def read_single_file(file_name):
    """Read one file whose base name (e.g. sports.txt) is the label for every line."""
    contents = []
    label = os.path.basename(file_name).split('.')[0]
    with open(file_name, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            content = line.strip().split(' ')
            if content:
                contents.append(content)
    return contents, label


def read_files(directory):
    contents = []
    labels = []
    files = [f for f in os.listdir(directory) if f.endswith(".txt")]
    for file_name in files:
        content, label = read_single_file(os.path.join(directory, file_name))
        contents.extend(content)
        labels.extend([label] * len(content))
    return contents, labels

def build_vocab(train_dir, vocab_file, vocab_size=5000):
    data_train, _ = read_files(train_dir)

    all_data = []
    for content in data_train:
        all_data.extend(content)

    counter = Counter(all_data)
    count_pairs = counter.most_common(vocab_size - 1)
    words, _ = list(zip(*count_pairs))
    # Reserve id 0 for the padding token.
    words = ['<PAD>'] + list(words)
    with open(vocab_file, mode='w', encoding='utf-8', errors='ignore') as fp:
        fp.write('\n'.join(words) + '\n')


def read_vocab(vocab_file):
    with open(vocab_file, mode='r', encoding='utf-8', errors='ignore') as fp:
        words = [_.strip() for _ in fp.readlines()]
    word_to_id = dict(zip(words, range(len(words))))
    return words, word_to_id


def read_category():
    categories = ['car', 'entertainment', 'military', 'sports', 'technology']
    cat_to_id = dict(zip(categories, range(len(categories))))
    return categories, cat_to_id

def encode_sentences(contents, words):
    return [encode_cate(x, words) for x in contents]


def encode_cate(content, words):
    # Unknown keys fall back to id 40000, the out-of-vocabulary bucket;
    # this is why the training script uses max_features=40001.
    return [(words[x] if x in words else 40000) for x in content]

def process_file(file_name, word_to_id, cat_to_id, max_length=600):
    contents, labels = read_file(file_name)

    data_id, label_id = [], []
    for i in range(len(contents)):
        data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])
        label_id.append(cat_to_id[labels[i]])

    x_pad = pad_sequences(data_id, max_length)
    y_pad = to_categorical(label_id, num_classes=len(cat_to_id))

    return x_pad, y_pad

def batch_iter(x, y, batch_size=64):
    """Yield shuffled (x, y) minibatches; x and y must be NumPy arrays."""
    data_len = len(x)
    num_batch = int((data_len - 1) / batch_size) + 1

    indices = np.random.permutation(np.arange(data_len))
    x_shuffle = x[indices]
    y_shuffle = y[indices]

    for i in range(num_batch):
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]
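
As a quick end-to-end check of how these helpers compose, a minimal sketch; the ./data layout with files such as sports.txt is an assumption taken from the training defaults, and the ./vocab directory is assumed to exist:

from utils import (build_vocab, read_vocab, read_category, read_files,
                   encode_sentences, encode_cate, batch_iter)
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical

# Assumed layout: ./data/sports.txt etc., one space-tokenized document per line.
build_vocab('./data', './vocab/vocab.txt', vocab_size=40000)
words, word_to_id = read_vocab('./vocab/vocab.txt')
categories, cat_to_id = read_category()

contents, labels = read_files('./data')
x = pad_sequences(encode_sentences(contents, word_to_id), maxlen=100)
y = to_categorical(encode_cate(labels, cat_to_id), num_classes=len(cat_to_id))

for x_batch, y_batch in batch_iter(x, y, batch_size=64):
    print(x_batch.shape, y_batch.shape)  # e.g. (64, 100) (64, 5)
    break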