mnist compare to pytorch
borgwang committed Nov 15, 2021
1 parent 08d8692 commit c6fb655
Showing 2 changed files with 174 additions and 10 deletions.
164 changes: 164 additions & 0 deletions examples/mnist/pytorch-run.py
@@ -0,0 +1,164 @@
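# PyTorch counterpart of the tinynn MNIST example: the models and training loop use
# torch, while tinynn supplies the dataset, batch iterator, seeding, and accuracy metric.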
import argparse
import os
import time

import tinynn as tn

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Dense(nn.Module):
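# Plain MLP: 784 -> 200 -> 100 -> 70 -> 30 -> 10 with ReLU activations,
# Xavier-initialized weights, and a log-softmax output.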

def __init__(self):
super(Dense, self).__init__()
self.fc1 = nn.Linear(784, 200)
self.fc2 = nn.Linear(200, 100)
self.fc3 = nn.Linear(100, 70)
self.fc4 = nn.Linear(70, 30)
self.fc5 = nn.Linear(30, 10)
torch.nn.init.xavier_uniform_(self.fc1.weight)
torch.nn.init.xavier_uniform_(self.fc2.weight)
torch.nn.init.xavier_uniform_(self.fc3.weight)
torch.nn.init.xavier_uniform_(self.fc4.weight)
torch.nn.init.xavier_uniform_(self.fc5.weight)

def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = F.relu(self.fc4(x))
x = self.fc5(x)
x = F.log_softmax(x, dim=1)
return x


class Conv(nn.Module):
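# LeNet-style CNN: two 5x5 "same"-padded conv layers, each followed by 2x2 max pooling
# (28x28 -> 14x14 -> 7x7), giving 16 * 7 * 7 = 784 flattened features for the fully
# connected head.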

def __init__(self):
super(Conv, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5, 1, padding="same")
self.conv2 = nn.Conv2d(6, 16, 5, 1, padding="same")

self.fc1 = nn.Linear(784, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = F.max_pool2d(x, kernel_size=2, stride=2)

x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, kernel_size=2, stride=2)

x = torch.flatten(x, 1)

x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)

x = F.log_softmax(x, dim=1)
return x


class RNN(nn.Module):
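# Reads each image as a sequence of 28 rows (28 features per step) and classifies
# from the output at the last time step.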

def __init__(self):
super(RNN, self).__init__()
self.recurrent = nn.RNN(28, 30, batch_first=True)
self.fc1 = nn.Linear(30, 10)

def forward(self, x):
output, hidden = self.recurrent(x)
x = output[:, -1]
x = self.fc1(x)
x = F.log_softmax(x, dim=1)
return x


class LSTM(RNN):
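# Identical to RNN except the recurrent cell is an LSTM; forward() is inherited.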

def __init__(self):
super(LSTM, self).__init__()
self.recurrent = nn.LSTM(28, 30, batch_first=True)
self.fc1 = nn.Linear(30, 10)


def main():
if args.seed >= 0:
tn.seeder.random_seed(args.seed)
torch.manual_seed(args.seed)

mnist = tn.dataset.MNIST(args.data_dir, one_hot=False)
train_x, train_y = mnist.train_set
test_x, test_y = mnist.test_set

if args.model_type == "mlp":
model = Dense()
elif args.model_type == "cnn":
train_x = train_x.reshape((-1, 1, 28, 28))
test_x = test_x.reshape((-1, 1, 28, 28))
model = Conv()
elif args.model_type == "rnn":
train_x = train_x.reshape((-1, 28, 28))
test_x = test_x.reshape((-1, 28, 28))
model = RNN()
elif args.model_type == "lstm":
train_x = train_x.reshape((-1, 28, 28))
test_x = test_x.reshape((-1, 28, 28))
model = LSTM()

model.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

model.train()
iterator = tn.data_iterator.BatchIterator(batch_size=args.batch_size)
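# Standard training loop: NLL loss on the log-softmax outputs (i.e. cross-entropy),
# one Adam step per mini-batch, and a per-epoch timing printout.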
for epoch in range(args.num_ep):
t_start = time.time()
f_cost, b_cost = 0, 0
for batch in iterator(train_x, train_y):
x = torch.from_numpy(batch.inputs).to(device)
y = torch.from_numpy(batch.targets).to(device)
optimizer.zero_grad()
pred = model(x)
loss = F.nll_loss(pred, y)
loss.backward()
optimizer.step()
print(f"Epoch {epoch} time cost: {time.time() - t_start}")
# evaluate
evaluate(model, test_x, test_y)


def evaluate(model, test_x, test_y):
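# Single forward pass over the whole test set with gradients disabled; accuracy is
# computed with tinynn's metric helper.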
model.eval()
x, y = torch.from_numpy(test_x).to(device), torch.from_numpy(test_y).to(device)
with torch.no_grad():
pred = model(x)
test_pred_idx = pred.argmax(dim=1).numpy()
accuracy, info = tn.metric.accuracy(test_pred_idx, test_y)
print(f"accuracy: {accuracy:.4f} info: {info}")


if __name__ == "__main__":
curr_dir = os.path.dirname(os.path.abspath(__file__))

parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", type=str,
default=os.path.join(curr_dir, "data"))
parser.add_argument("--model_type", default="mlp", type=str,
help="[*mlp|cnn|rnn|lstm]")
parser.add_argument("--num_ep", default=10, type=int)
parser.add_argument("--lr", default=1e-3, type=float)
parser.add_argument("--batch_size", default=128, type=int)
parser.add_argument("--seed", default=31, type=int)
args = parser.parse_args()
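# Example invocation (run from examples/mnist/, assuming tn.dataset.MNIST can find or
# fetch the data under --data_dir):
#   python pytorch-run.py --model_type cnn --num_ep 10 --batch_size 128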

device = torch.device("cpu")
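# Hard-coded to CPU, which keeps timings directly comparable with the tinynn example runner.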

main()
20 changes: 10 additions & 10 deletions tinynn/utils/dataset.py
@@ -45,7 +45,7 @@ def test_set(self):
return self._test_set

@staticmethod
-def get_one_hot(targets, n_classes):
+def one_hot(targets, n_classes):
return np.eye(n_classes, dtype=np.float32)[np.array(targets).reshape(-1)]


@@ -63,9 +63,9 @@ def _parse(self, **kwargs):
train, valid, test = pickle.load(f, encoding="latin1")

if kwargs["one_hot"]:
-train = (train[0], self.get_one_hot(train[1], self._n_classes))
-valid = (valid[0], self.get_one_hot(valid[1], self._n_classes))
-test = (test[0], self.get_one_hot(test[1], self._n_classes))
+train = (train[0], self.one_hot(train[1], self._n_classes))
+valid = (valid[0], self.one_hot(valid[1], self._n_classes))
+test = (test[0], self.one_hot(test[1], self._n_classes))

self._train_set, self._valid_set, self._test_set = train, valid, test

@@ -103,8 +103,8 @@ def _parse(self, **kwargs):
test_x = test_x.reshape((len(test_x), -1))

if kwargs["one_hot"]:
-train_y = self.get_one_hot(train_y, self._n_classes)
-test_y = self.get_one_hot(test_y, self._n_classes)
+train_y = self.one_hot(train_y, self._n_classes)
+test_y = self.one_hot(test_y, self._n_classes)
self._train_set = (train_x, train_y)
self._test_set = (test_x, test_y)

@@ -164,8 +164,8 @@ def _parse(self, **kwargs):
test_x = self._cifar_normalize(test_x)

if kwargs["one_hot"]:
-train_y = self.get_one_hot(train_y, self._n_classes)
-test_y = self.get_one_hot(test_y, self._n_classes)
+train_y = self.one_hot(train_y, self._n_classes)
+test_y = self.one_hot(test_y, self._n_classes)

self._train_set = (train_x, train_y)
self._test_set = (test_x, test_y)
@@ -192,8 +192,8 @@ def _parse(self, **kwargs):
test_x = self._cifar_normalize(test_x)

if kwargs["one_hot"]:
-train_y = self.get_one_hot(train_y, self._n_classes)
-test_y = self.get_one_hot(test_y, self._n_classes)
+train_y = self.one_hot(train_y, self._n_classes)
+test_y = self.one_hot(test_y, self._n_classes)

self._train_set = (train_x, train_y)
self._test_set = (test_x, test_y)
