Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sherrylone patch 2 #499

Merged
merged 3 commits into from
May 2, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
ch20 Level5_Pytorch.py
Add GRU model
  • Loading branch information
Sherrylone authored Apr 30, 2020
commit b3ae6c83c71248402afd3ca0b8a7d1d7af6f664d
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch.nn as nn
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.nn.init import xavier_normal


train_file = "../../data/ch19.train_minus.npz"
Expand All @@ -34,10 +35,36 @@ def process_data(dr): # 统一输入单词的长度,不足补0

return dr.XTrain, dr.XTest, dr.YTrain, dr.YTest # 这是为了统一一个表达方式,和之前的章节一样

def weights_init(m):
    """Xavier-initialize convolutional layers (intended for ``model.apply``).

    Args:
        m: a submodule passed in by ``nn.Module.apply``. Only modules whose
           class name contains 'Conv' are touched; everything else is left
           with its default initialization.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # `xavier_normal` is deprecated; the supported API is the in-place
        # variant `xavier_normal_`.
        torch.nn.init.xavier_normal_(m.weight.data)
        # Xavier init needs a tensor with >= 2 dims to compute fan-in/fan-out;
        # a conv bias is 1-D, so the original xavier_normal(m.bias.data)
        # raised ValueError. Zero the bias instead (the standard choice).
        if m.bias is not None:
            m.bias.data.zero_()

class LSTM(nn.Module):
    """Bidirectional LSTM for per-timestep binary classification.

    Input is a (batch, seq_len, 2) tensor; output is (batch, seq_len, 2)
    sigmoid activations. The span previously contained residual pre-rename
    lines (``class RNN`` / ``super(RNN, ...)``) left over from the diff;
    this is the post-commit class only.
    """

    def __init__(self):
        super(LSTM, self).__init__()
        self.rnn = nn.LSTM(
            input_size=2,        # character num.
            hidden_size=4,       # small on purpose, for comparison with earlier chapters;
                                 # a larger hidden size may work better
            num_layers=1,
            batch_first=True,
            bidirectional=True,  # bidirectional LSTM; if set to False,
                                 # double hidden_size to keep comparable capacity
        )
        self.softmax = nn.Softmax()  # NOTE(review): defined but never used in forward()
        self.fc = nn.Linear(8, 2)    # 8 = hidden_size * 2 directions -> 2 classes
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Many-to-many sequence output, e.g. shape (2, 4, 8) for a
        # (batch=2, seq=4) input; initial hidden state defaults to zeros.
        r_out, (h_n, h_c) = self.rnn(x, None)
        out = self.fc(r_out)         # per-step binary logits, e.g. (2, 4, 2)
        out = self.sigmoid(out)
        return out

class GRU(nn.Module):
def __init__(self):
super(GRU, self).__init__()
self.rnn = nn.LSTM(
input_size=2, # character num.
hidden_size=4, # RNN or LSTM hidden layer, 设置的稍大一些可能效果更佳,此处仅作对比
Expand Down Expand Up @@ -77,7 +104,9 @@ def accracy_score(pred, label):
# Training hyper-parameters.
max_epoch = 100
lr = 1e-2
batch_size = 2

# Model selection: the diff left the residual pre-rename line `rnn = RNN()`
# in place; `RNN` was renamed to `LSTM` in this commit, so that line would
# raise NameError at import time. Only the post-commit lines remain.
# rnn = LSTM()  # LSTM model (alternative)
rnn = GRU()     # GRU model
rnn.apply(weights_init)

# Data processing
dataReader = load_data()
Expand Down