Skip to content

Commit 30c498f

Browse files
committed
fix bug: correct `Variabel` typo to `Variable` in the eval loop; reformat long calls and spacing
1 parent 8988438 commit 30c498f

File tree

1 file changed

+17
-25
lines changed

1 file changed

+17
-25
lines changed

05-Recurrent Neural Network/recurrent_network.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,11 @@
1313
num_epoches = 20
1414

1515
# Download the MNIST handwritten-digit training set
16-
train_dataset = datasets.MNIST(root='./data', train=True,
17-
transform=transforms.ToTensor(),
18-
download=True)
16+
train_dataset = datasets.MNIST(
17+
root='./data', train=True, transform=transforms.ToTensor(), download=True)
1918

20-
test_dataset = datasets.MNIST(root='./data', train=False,
21-
transform=transforms.ToTensor())
19+
test_dataset = datasets.MNIST(
20+
root='./data', train=False, transform=transforms.ToTensor())
2221

2322
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
2423
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
@@ -30,15 +29,14 @@ def __init__(self, in_dim, hidden_dim, n_layer, n_class):
3029
super(Rnn, self).__init__()
3130
self.n_layer = n_layer
3231
self.hidden_dim = hidden_dim
33-
self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer,
34-
batch_first=True)
32+
self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
3533
self.classifier = nn.Linear(hidden_dim, n_class)
3634

3735
def forward(self, x):
3836
# h0 = Variable(torch.zeros(self.n_layer, x.size(1),
39-
# self.hidden_dim)).cuda()
37+
# self.hidden_dim)).cuda()
4038
# c0 = Variable(torch.zeros(self.n_layer, x.size(1),
41-
# self.hidden_dim)).cuda()
39+
# self.hidden_dim)).cuda()
4240
out, _ = self.lstm(x)
4341
out = out[:, -1, :]
4442
out = self.classifier(out)
@@ -55,8 +53,8 @@ def forward(self, x):
5553

5654
# Start training
5755
for epoch in range(num_epoches):
58-
print('epoch {}'.format(epoch+1))
59-
print('*'*10)
56+
print('epoch {}'.format(epoch + 1))
57+
print('*' * 10)
6058
running_loss = 0.0
6159
running_acc = 0.0
6260
for i, data in enumerate(train_loader, 1):
@@ -87,15 +85,11 @@ def forward(self, x):
8785

8886
if i % 300 == 0:
8987
print('[{}/{}] Loss: {:.6f}, Acc: {:.6f}'.format(
90-
epoch+1, num_epoches,
91-
running_loss/(batch_size*i),
92-
running_acc/(batch_size*i)
93-
))
88+
epoch + 1, num_epoches, running_loss / (batch_size * i),
89+
running_acc / (batch_size * i)))
9490
print('Finish {} epoch, Loss: {:.6f}, Acc: {:.6f}'.format(
95-
epoch+1,
96-
running_loss/(len(train_dataset)),
97-
running_acc/(len(train_dataset))
98-
))
91+
epoch + 1, running_loss / (len(train_dataset)), running_acc / (len(
92+
train_dataset))))
9993
model.eval()
10094
eval_loss = 0
10195
eval_acc = 0
@@ -111,18 +105,16 @@ def forward(self, x):
111105
img = Variable(img, volatile=True).cuda()
112106
label = Variable(label, volatile=True).cuda()
113107
else:
114-
img = Variabel(img, volatile=True)
108+
img = Variable(img, volatile=True)
115109
label = Variable(label, volatile=True)
116110
out = model(img)
117111
loss = criterion(out, label)
118-
eval_loss += loss.data[0]*label.size(0)
112+
eval_loss += loss.data[0] * label.size(0)
119113
_, pred = torch.max(out, 1)
120114
num_correct = (pred == label).sum()
121115
eval_acc += num_correct.data[0]
122-
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(
123-
eval_loss/(len(test_dataset)),
124-
eval_acc/(len(test_dataset))
125-
))
116+
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
117+
test_dataset)), eval_acc / (len(test_dataset))))
126118
print()
127119

128120
# Save the model

0 commit comments

Comments
 (0)