-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmnist.py
67 lines (47 loc) · 1.72 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright (c) Ye Liu. Licensed under the MIT License.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from nncore.engine import Engine
class LeNet(nn.Module):
    """LeNet-5 style classifier for single-channel images.

    The backbone is three 5x5 convolutions (no padding) with ``Tanh``
    activations and 2x2 average pooling between them. For a 1x32x32 input the
    spatial size shrinks 32 -> 28 -> 14 -> 10 -> 5 -> 1, so the backbone emits
    a ``(N, 120, 1, 1)`` feature map, which the fully connected head maps to
    10 class logits.
    """

    def __init__(self):
        super().__init__()
        # yapf:disable
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 6, 5),
            nn.Tanh(),
            nn.AvgPool2d(2),
            nn.Conv2d(6, 16, 5),
            nn.Tanh(),
            nn.AvgPool2d(2),
            nn.Conv2d(16, 120, 5))
        self.head = nn.Sequential(
            nn.Linear(120, 84),
            nn.Tanh(),
            nn.Linear(84, 10))
        # yapf:enable
        self.loss = nn.CrossEntropyLoss()

    def forward(self, data, **kwargs):
        """Classify a batch and compute its accuracy and loss.

        Args:
            data: Indexable batch whose item 0 is the image tensor and item 1
                the integer class labels (e.g. an ``(images, labels)`` pair
                yielded by a ``DataLoader``). Images must have one channel,
                and the head assumes the backbone reduces them to 1x1, which
                holds for 32x32 inputs.
            **kwargs: Accepted and ignored; presumably extra arguments passed
                by the training engine -- verify against nncore.

        Returns:
            dict with ``_avg_factor`` (the batch size ``N``; presumably used by
            the engine to weight averaged metrics), ``acc`` (fraction of
            correct top-1 predictions, a float scalar tensor) and ``loss``
            (``CrossEntropyLoss`` with its default mean reduction).
        """
        x, y = data[0], data[1]
        # flatten(1) rather than squeeze(): squeeze() would also drop the batch
        # dimension when N == 1, breaking argmax(dim=1) and the batch size.
        x = self.backbone(x).flatten(1)
        x = self.head(x)
        out = torch.argmax(x, dim=1)
        acc = torch.eq(out, y).sum().float() / x.size(0)
        loss = self.loss(x, y)
        return dict(_avg_factor=x.size(0), acc=acc, loss=loss)
def main():
    """Train and evaluate ``LeNet`` on MNIST with the nncore ``Engine``.

    Both MNIST splits are read from (and downloaded into, if missing) the
    ``data`` directory. Every image is converted to a tensor, resized so its
    smaller edge is 32 pixels, and normalized with mean 0.5 / std 0.5.
    """
    # One preprocessing pipeline shared by the training and validation splits.
    preprocess = Compose([ToTensor(), Resize(32), Normalize(0.5, 0.5)])

    loaders = {}
    train_split = MNIST('data', train=True, transform=preprocess, download=True)
    loaders['train'] = DataLoader(train_split, batch_size=16, shuffle=True)
    val_split = MNIST('data', train=False, transform=preprocess, download=True)
    loaders['val'] = DataLoader(val_split, batch_size=64, shuffle=False)

    # Build the model, hand it to the engine together with the loaders, and run.
    Engine(LeNet(), loaders).launch()
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()