Skip to content

Commit cecd049

Browse files
committed
2-1 mnist upgrade
1 parent 186f48d commit cecd049

File tree

5 files changed

+7469
-0
lines changed

5 files changed

+7469
-0
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
5+
class Block(nn.Module):
6+
7+
def __init__(self,
8+
input_size,
9+
output_size,
10+
use_batch_norm=True,
11+
dropout_p=.4):
12+
self.input_size = input_size
13+
self.output_size = output_size
14+
self.use_batch_norm = use_batch_norm
15+
self.dropout_p = dropout_p
16+
17+
super().__init__()
18+
19+
def get_regularizer(use_batch_norm, size):
20+
return nn.BatchNorm1d(size) if use_batch_norm else nn.Dropout(dropout_p)
21+
22+
self.block = nn.Sequential(
23+
nn.Linear(input_size, output_size),
24+
nn.LeakyReLU(),
25+
get_regularizer(use_batch_norm, output_size),
26+
)
27+
28+
def forward(self, x):
29+
# |x| = (batch_size, input_size)
30+
y = self.block(x)
31+
# |y| = (batch_size, output_size)
32+
33+
return y
34+
35+
36+
class ImageClassifier(nn.Module):
37+
38+
def __init__(self,
39+
input_size,
40+
output_size,
41+
hidden_sizes=[500, 400, 300, 200, 100],
42+
use_batch_norm=True,
43+
dropout_p=.3):
44+
45+
super().__init__()
46+
47+
assert len(hidden_sizes) > 0, "You need to specify hidden layers"
48+
49+
last_hidden_size = input_size
50+
blocks = []
51+
for hidden_size in hidden_sizes:
52+
blocks += [Block(
53+
last_hidden_size,
54+
hidden_size,
55+
use_batch_norm,
56+
dropout_p
57+
)]
58+
last_hidden_size = hidden_size
59+
60+
self.layers = nn.Sequential(
61+
*blocks,
62+
nn.Linear(last_hidden_size, output_size),
63+
nn.LogSoftmax(dim=-1),
64+
)
65+
66+
def forward(self, x):
67+
# |x| = (batch_size, input_size)
68+
y = self.layers(x)
69+
# |y| = (batch_size, output_size)
70+
71+
return y

0 commit comments

Comments
 (0)