Commit 2b1d8ab (initial commit, no parents) — showing 8 changed files with 614 additions and 0 deletions.
@@ -0,0 +1,10 @@
report/
data/
cnn-text-classification-pytorch-master/
.vscode/
__pycache__/
*.zip
*.ckpt
note.md
test.py
example.py
@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from torch.nn import functional as F


class RCNN(nn.Module):
    def __init__(self, config):
        super(RCNN, self).__init__()

        self.label_num = config.label_num
        self.hidden_size = 100
        self.hidden_size_linear = 64  # dimension after pooling
        self.embedding_length = config.wordvec_dim
        self.device = config.device
        self.embeddings = nn.Embedding(400000, self.embedding_length).to(config.device)
        self.embeddings = self.embeddings.from_pretrained(config.weight, freeze=False)
        self.lstm = nn.LSTM(input_size=self.embedding_length,
                            hidden_size=self.hidden_size,
                            num_layers=1,
                            bidirectional=True, batch_first=True)

        self.dropout = nn.Dropout(0.2)  # 20% of the elements are zeroed

        self.W = nn.Linear(2 * self.hidden_size + self.embedding_length, self.hidden_size_linear)
        # self.W = nn.Linear(self.hidden_size + self.embedding_length, self.hidden_size_linear)

        self.tanh = nn.Tanh()
        self.fc = nn.Linear(self.hidden_size_linear, config.label_num)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        embedded_sent = self.embeddings(x)  # (batch_size, seq_len, embed_size)

        lstm_out, (h_n, c_n) = self.lstm(embedded_sent)  # (batch_size, seq_len, 2 * hidden_size)

        # Concatenate the bidirectional context with the word embedding at every position
        input_features = torch.cat([lstm_out, embedded_sent], 2)  # (batch_size, seq_len, embed_size + 2*hidden_size)

        # Per-position concatenation as described in the original paper; very slow
        # input_features = torch.zeros((x.size()[0], x.size()[1], self.hidden_size + self.embedding_length), device=self.device)
        # for j in range(x.size()[1]):
        #     for h in range(self.hidden_size):
        #         input_features[:, j, :] = torch.cat([lstm_out[:, j, :h], embedded_sent[:, j, :], lstm_out[:, j, h - self.hidden_size:]], dim=1)
        # A partially vectorized variant of the same idea:
        # input_features = torch.zeros((x.size()[0], x.size()[1], self.hidden_size + self.embedding_length), device=self.device)
        # for h in range(self.hidden_size):
        #     input_features[:, :, :] = torch.cat([lstm_out[:, :, :h], embedded_sent[:, :, :], lstm_out[:, :, h - self.hidden_size:]], dim=2)

        linear_output = self.tanh(self.W(input_features))  # (batch_size, seq_len, hidden_size_linear)

        linear_output = linear_output.permute(0, 2, 1)

        # Max pooling over the sequence dimension
        max_out_features = F.max_pool1d(linear_output, linear_output.shape[2]).squeeze(2)  # (batch_size, hidden_size_linear)

        # max_out_features = self.dropout(max_out_features)
        final_out = self.fc(max_out_features)
        # return F.softmax(final_out, dim=1)
        return final_out
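A minimal smoke test for the RCNN module above, a sketch rather than part of the commit: the config namespace is hypothetical and only provides the fields the constructor reads (label_num, wordvec_dim, device, weight), with a random matrix standing in for the real GloVe table.

```python
import torch
from types import SimpleNamespace

# Hypothetical config; the field names mirror what RCNN.__init__ reads.
config = SimpleNamespace(
    label_num=2,                      # SST-2
    wordvec_dim=50,                   # GloVe 6B.50d
    device='cpu',
    weight=torch.randn(400000, 50),   # placeholder for the GloVe lookup table
)

model = RCNN(config)
dummy_batch = torch.randint(0, 400000, (4, 20))  # (batch_size=4, seq_len=20) word indices
logits = model(dummy_batch)
print(logits.shape)                   # expected: torch.Size([4, 2])
```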
@@ -0,0 +1,27 @@
# Text Classification on SST Dataset

A simple text classification project on the SST dataset, implemented in PyTorch.

## Models

- Sentence CNN
- Vanilla RNN
- (Bidirectional) LSTM
- RCNN (Recurrent Convolutional Neural Network)
- An ensemble of the above

## Accuracy

Model | Embedding | Fine-tuning | SST-2 | SST-5
--|--|--|--|--
CNN | GloVe 6B.50d | N | 71.95 | N/A
CNN | GloVe 6B.50d | Y | 78.36 | 42.85
CNN | GloVe 6B.300d | Y | 78.42 | 43.10
RNN | GloVe 6B.50d | Y | 73.26 | 38.78
RNN | GloVe 6B.300d | Y | 75.88 | 38.64
LSTM | GloVe 6B.50d | N | 74.21 | N/A
LSTM | GloVe 6B.50d | Y | 75.97 | 39.37
LSTM | GloVe 6B.300d | Y | 78.05 | 40.54
RCNN | GloVe 6B.50d | N | 75.52 | N/A
RCNN | GloVe 6B.300d | Y | **80.41** | **45.02**
Ensemble | GloVe 6B.300d | Y | **81.36** | **46.02**
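The Fine-tuning column in the table corresponds to the `freeze` flag used when the models load the pretrained GloVe weights; a minimal sketch of the two settings, with a random tensor standing in for the real GloVe table:

```python
import torch
import torch.nn as nn

weight = torch.randn(400000, 50)  # placeholder for GloVe 6B.50d

# Fine-tuning = N: embedding weights stay fixed during training
frozen = nn.Embedding.from_pretrained(weight, freeze=True)

# Fine-tuning = Y: embedding weights are updated by the optimizer (as in this commit's models)
tuned = nn.Embedding.from_pretrained(weight, freeze=False)
```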
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn
from torch.nn import functional as F


class myRNN(nn.Module):
    def __init__(self, config):
        super(myRNN, self).__init__()

        self.hidden_size = 30  # dimension of the hidden state
        self.num_layers = 2  # number of stacked RNN layers
        self.embedding_length = config.wordvec_dim
        self.word_embeddings = nn.Embedding(
            400000, self.embedding_length)  # embedding layer
        self.word_embeddings = self.word_embeddings.from_pretrained(
            config.weight, freeze=False)  # load pretrained word embeddings and fine-tune them
        self.recurrent = nn.RNN(self.embedding_length, self.hidden_size,
                                num_layers=self.num_layers, bidirectional=True, batch_first=True)  # recurrent layer
        self.fc = nn.Linear(2 * self.num_layers *
                            self.hidden_size, config.label_num)  # fully connected layer

    def forward(self, input_sentences):

        x = self.word_embeddings(input_sentences)  # (batch_size, seq_len, embedding_length)

        output, h_n = self.recurrent(x)  # output features, final hidden state
        # h_n.size() = (2*self.num_layers, batch_size, hidden_size), 2 for bidirectional

        # (batch_size, 2*self.num_layers, hidden_size)
        h_n = h_n.permute(1, 0, 2)

        h_n = h_n.contiguous().view(h_n.size()[0], h_n.size()[1] * h_n.size()[2])  # (batch_size, 2*num_layers*hidden_size)

        logits = self.fc(h_n)  # (batch_size, label_num)

        return F.softmax(logits, dim=1)


class LSTMClassifier(nn.Module):
    def __init__(self, config):
        super(LSTMClassifier, self).__init__()

        self.label_num = config.label_num
        self.hidden_size = 150
        self.embedding_length = config.wordvec_dim
        self.word_embeddings = nn.Embedding(
            400000, self.embedding_length)  # embedding layer
        self.word_embeddings = self.word_embeddings.from_pretrained(
            config.weight, freeze=False)  # load pretrained word embeddings and fine-tune them

        self.lstm = nn.LSTM(self.embedding_length,
                            self.hidden_size, batch_first=True)  # LSTM layer
        self.fc = nn.Linear(self.hidden_size, self.label_num)

    def forward(self, input_sentence):
        x = self.word_embeddings(input_sentence)  # (batch_size, seq_len, embedding_length)

        output, (final_hidden_state, final_cell_state) = self.lstm(x)

        # final_hidden_state.size() = (1, batch_size, hidden_size); logits.size() = (batch_size, label_num)
        logits = self.fc(final_hidden_state[-1])

        return F.softmax(logits, dim=1)
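A sketch of the hidden-state bookkeeping in myRNN, not part of the commit: with num_layers=2 and a bidirectional RNN, h_n stacks 2 * num_layers = 4 direction/layer slices, which the permute/view pair flattens into the (batch_size, 120) vector the final linear layer expects. The config namespace below is hypothetical.

```python
import torch
from types import SimpleNamespace

# Hypothetical config with only the fields myRNN.__init__ reads.
config = SimpleNamespace(label_num=5, wordvec_dim=50, weight=torch.randn(400000, 50))
model = myRNN(config)

batch = torch.randint(0, 400000, (8, 25))  # (batch_size=8, seq_len=25) word indices
probs = model(batch)
print(probs.shape)       # expected: torch.Size([8, 5])
print(probs.sum(dim=1))  # rows sum to 1 because forward() applies softmax
```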
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F


class TextCNN(nn.Module):

    def __init__(self, config):
        super(TextCNN, self).__init__()
        self.kernel_num = config.kernel_num  # number of conv kernels per kernel size
        self.embedding_length = config.wordvec_dim
        self.Ks = config.kernel_sizes
        self.word_embeddings = nn.Embedding(400000, self.embedding_length)  # embedding layer
        self.word_embeddings = self.word_embeddings.from_pretrained(config.weight, freeze=False)  # load pretrained word embeddings and fine-tune them

        self.convs = nn.ModuleList([nn.Conv2d(1, config.kernel_num, (K, config.wordvec_dim), bias=True) for K in self.Ks])
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(len(self.Ks) * self.kernel_num, config.label_num)
        self.bn1 = nn.BatchNorm1d(len(self.Ks) * self.kernel_num)
        # self.fc1 = nn.Linear(len(self.Ks)*self.kernel_num, fc1out)
        # self.bn2 = nn.BatchNorm1d(fc1out)
        # self.fc2 = nn.Linear(fc1out, config.label_num)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # x: (batch_size, seq_len) word indices
        # Embedding
        x = self.word_embeddings(x)  # (batch_size, seq_len, embedding_len)

        x = x.unsqueeze(1)  # (batch_size, 1, seq_len, embedding_len)

        # Convolution and ReLU
        x = [self.relu(conv(x)).squeeze(3) for conv in self.convs]  # [(batch_size, kernel_num, seq_len-K+1), ...] * len(Ks)

        # Max-over-time pooling (simply a max over the remaining sequence dimension)
        x = [F.max_pool1d(xi, xi.size()[2]).squeeze(2) for xi in x]  # [(batch_size, kernel_num), ...] * len(Ks)
        x = torch.cat(x, dim=1)  # (batch_size, kernel_num*len(Ks))

        x = self.bn1(x)  # batch normalization
        x = self.dropout(x)  # (batch_size, len(Ks)*kernel_num), dropout

        x = self.fc(x)  # (batch_size, label_num)

        # x = self.fc1(x)
        # x = self.bn2(x)
        # x = self.fc2(x)
        logit = self.softmax(x)
        return logit


if __name__ == '__main__':
    pass
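A hedged sketch of how kernel_sizes and kernel_num determine the size of the feature vector fed to TextCNN's classifier; the config values below are illustrative assumptions, not part of the commit.

```python
import torch
from types import SimpleNamespace

# Hypothetical config; three kernel sizes each contribute kernel_num pooled features.
config = SimpleNamespace(
    kernel_num=100,
    kernel_sizes=[3, 4, 5],
    wordvec_dim=50,
    label_num=2,
    weight=torch.randn(400000, 50),  # placeholder for the GloVe table
)

model = TextCNN(config)
model.eval()                                  # BatchNorm1d needs eval mode for very small batches
batch = torch.randint(0, 400000, (4, 30))     # (batch_size=4, seq_len=30); seq_len must be >= max kernel size
probs = model(batch)
print(probs.shape)  # expected: torch.Size([4, 2]); 3 kernel sizes * 100 kernels = 300 pooled features
```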
@@ -0,0 +1,111 @@
import numpy as np
import pandas as pd
import re
import torch
from torch.utils.data import Dataset


def clean_data(sentence):
    # From yoonkim: https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    sentence = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", sentence)
    sentence = re.sub(r"\s{2,}", " ", sentence)
    return sentence.strip().lower()


def get_class(sentiment, num_classes):
    # Map a sentiment value in [0, 1] to a class label in {0, ..., num_classes-1}
    return int(sentiment * (num_classes - 0.001))


def loadGloveModel(gloveFile):
    glove = pd.read_csv(gloveFile, sep=' ', header=None, encoding='utf-8', index_col=0, na_values=None, keep_default_na=False, quoting=3)
    return glove  # indexed by word, 400k rows * dim columns

    # Equivalent manual loader:
    # with open(gloveFile, 'r', encoding='utf-8') as f:
    #     dim = gloveFile.split('.')[2][:-1]  # embedding dimension
    #     model = {}
    #     lookup_tab = np.zeros((400000, int(dim)))
    #     for i, line in enumerate(f):
    #         splitLine = line.split()
    #         word = splitLine[0]  # word
    #         embedding = np.array([float(val) for val in splitLine[1:]])  # word vector
    #         model[word] = (embedding, i)  # (embedding, index)
    #         lookup_tab[i] = embedding
    #     return model, lookup_tab


class SSTDataset(Dataset):
    label_tmp = None

    def __init__(self, path_to_dataset, name, num_classes, wordvec_dim, wordvec, device='cpu'):
        """SST dataset
        Args:
            path_to_dataset (str): path to the dataset directory
            name (str): train, dev or test
            num_classes (int): 2 or 5
            wordvec_dim (int): word vector dimension
            wordvec (array): word vectors (GloVe)
            device (str, optional): device to run on. Defaults to 'cpu'.
        """
        phrase_ids = pd.read_csv(path_to_dataset + 'phrase_ids.' +
                                 name + '.txt', header=None, encoding='utf-8', dtype=int)
        phrase_ids = set(np.array(phrase_ids).squeeze())  # phrase ids that appear in this split
        self.num_classes = num_classes
        phrase_dict = {}  # {phrase id: phrase}

        if SSTDataset.label_tmp is None:
            # Read the labels (sentiment values) first.
            # They are shared by the train/dev/test splits, so read them only once.
            SSTDataset.label_tmp = pd.read_csv(path_to_dataset + 'sentiment_labels.txt',
                                               sep='|', dtype={'phrase ids': int, 'sentiment values': float})
            SSTDataset.label_tmp = np.array(SSTDataset.label_tmp)[:, 1:]  # sentiment values

        with open(path_to_dataset + 'dictionary.txt', 'r', encoding='utf-8') as f:
            for line in f:
                phrase, phrase_id = line.strip().split('|')
                if int(phrase_id) in phrase_ids:  # appears in this split
                    phrase = clean_data(phrase)  # preprocess
                    phrase_dict[int(phrase_id)] = phrase

        # GloVe indices of the words in each sentence
        self.phrase_vec = []

        # label of each sentence
        self.labels = torch.zeros((len(phrase_dict),), dtype=torch.long)

        i = 0
        missing_count = 0
        # Look up the GloVe index of every word in every sentence
        for idx, p in phrase_dict.items():
            tmp1 = []  # word indices of this sentence
            # tokenize
            for w in p.split(' '):
                try:
                    tmp1.append(wordvec.index.get_loc(w))  # index of word w in GloVe
                except KeyError:
                    missing_count += 1

            self.phrase_vec.append(torch.tensor(tmp1, dtype=torch.long))  # GloVe index of every word in the sentence
            self.labels[i] = get_class(SSTDataset.label_tmp[idx], self.num_classes)  # label of the sentence at position i
            i += 1

        print(missing_count)

    def __getitem__(self, index):
        return self.phrase_vec[index], self.labels[index]

    def __len__(self):
        return len(self.phrase_vec)


if __name__ == "__main__":
    # test
    wordvec = loadGloveModel('data/glove/glove.6B.' + str(50) + 'd.txt')
    test = SSTDataset('data/dataset/', 'test', 2, 50, wordvec)
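SSTDataset.__getitem__ returns index tensors of varying length, so a DataLoader needs a collate function that pads each batch before the models above can consume it. A minimal sketch, not part of the commit; padding with index 0 is an assumption (it simply reuses the first GloVe row):

```python
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    # batch is a list of (phrase_vec, label) pairs with varying phrase lengths
    phrases, labels = zip(*batch)
    phrases = pad_sequence(phrases, batch_first=True, padding_value=0)  # assumed pad index 0
    return phrases, torch.stack(labels)

# Usage sketch, reusing the dataset built in the __main__ block above
loader = DataLoader(test, batch_size=32, shuffle=True, collate_fn=pad_collate)
for phrases, labels in loader:
    print(phrases.shape, labels.shape)  # (32, max_len_in_batch), (32,)
    break
```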