# mnist_pnetcdf_cpu.py (forked from olehb/pytorch_ddp_tutorial)

import os

import numpy as np
import pncpy
import torch
from mpi4py import MPI
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
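
# Usage (a sketch; the exact launcher depends on your MPI installation, and
# the mnist_{train,test}_images.nc files are assumed to exist in dataset_loc):
#
#   mpiexec -n 1 python mnist_pnetcdf_cpu.py
#
# The script opens the NetCDF files on MPI.COMM_WORLD, so it must be started
# through an MPI launcher even for a single rank.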
# Set to False to show per-batch progress bars.
DISABLE_TQDM = True


class MNISTNetCDF(Dataset):
    """PyTorch Dataset that reads MNIST images and labels from a PnetCDF file."""

    def __init__(self, root_dir, is_train=True, transforms=None, comm=None):
        self.transforms = transforms
        print('=> Reading NetCDF file...')
        nc_path = os.path.join(root_dir, 'mnist_{}_images.nc'.format('train' if is_train else 'test'))
        self.nc = pncpy.File(nc_path, 'r', comm=comm)
        print('=> Dataset created, image nc file is: {}'.format(nc_path))

    def __len__(self):
        return self.nc.variables['images'].shape[0]

    def __getitem__(self, index):
        # Read one image from the 'images' variable.
        image = np.array(self.nc.variables['images'][index])
        # Fetch the label as a scalar via a collective PnetCDF read.
        buff = np.empty((), np.uint8)
        self.nc.variables['labels'].get_var_all(buff, index=(index,))
        if self.transforms:
            image = self.transforms(image)
        # Cast the label to a Python int so the default collate function
        # yields an int64 tensor, as required by nn.CrossEntropyLoss.
        return image, int(buff)
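
# A minimal, hypothetical smoke test for the dataset above (assumes the .nc
# files already exist in the current directory; not part of the training run):
#
#   comm = MPI.COMM_WORLD
#   ds = MNISTNetCDF(root_dir='.', is_train=False, comm=comm)
#   image, label = ds[0]
#   print(len(ds), image.shape, label)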
def create_data_loaders(dataset, batch_size: int, num_worker: int) -> DataLoader:
    # Note: this helper is currently unused; __main__ builds its loaders inline.
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=num_worker,
                        persistent_workers=num_worker > 0)
    # A distributed sampler is not necessary for the test or validation sets.
    return loader
def create_model():
    # A simple MLP: two hidden layers of 128 units, 10-way output.
    model = nn.Sequential(
        nn.Linear(28 * 28, 128),  # MNIST images are 28x28 pixels
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Linear(128, 10, bias=False)  # 10 classes to predict
    )
    return model
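
# Quick shape check for the model (a sketch; mirrors the x.view(x.shape[0], -1)
# flattening done in main()):
#
#   m = create_model()
#   out = m(torch.zeros(2, 28 * 28))  # a dummy batch of 2 flattened images
#   assert out.shape == (2, 10)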
def main(epochs: int,
         model: nn.Module,
         train_loader: DataLoader,
         test_loader: DataLoader):
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    loss = nn.CrossEntropyLoss()

    for i in range(epochs):
        # Train the model for one epoch.
        model.train()
        epoch_loss = 0
        pbar = tqdm(train_loader, disable=DISABLE_TQDM)
        for x, y in pbar:
            x = x.view(x.shape[0], -1)  # flatten images to (batch, 28*28)
            optimizer.zero_grad()
            y_hat = model(x)
            batch_loss = loss(y_hat, y)
            batch_loss.backward()
            optimizer.step()
            batch_loss_scalar = batch_loss.item()
            # Accumulate the (mean) batch loss, normalized by batch size.
            epoch_loss += batch_loss_scalar / x.shape[0]
            pbar.set_description(f'training batch_loss={batch_loss_scalar:.4f}')

        # Calculate validation loss.
        with torch.no_grad():
            model.eval()
            val_loss = 0
            pbar = tqdm(test_loader, disable=DISABLE_TQDM)
            for x, y in pbar:
                x = x.view(x.shape[0], -1)
                y_hat = model(x)
                batch_loss = loss(y_hat, y)
                batch_loss_scalar = batch_loss.item()
                val_loss += batch_loss_scalar / x.shape[0]
                pbar.set_description(f'validation batch_loss={batch_loss_scalar:.4f}')

        print(f"Epoch={i}, train_loss={epoch_loss:.4f}, val_loss={val_loss:.4f}")
if __name__ == '__main__':
    comm = MPI.COMM_WORLD
    batch_size = 128
    epochs = 1
    num_worker = 4
    dataset_loc = '.'

    # Standard MNIST normalization (mean/std of the training set).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = MNISTNetCDF(root_dir=dataset_loc,
                                is_train=True,
                                transforms=transform,
                                comm=comm)
    test_dataset = MNISTNetCDF(root_dir=dataset_loc,
                               is_train=False,
                               transforms=transform,
                               comm=comm)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_worker,
                              persistent_workers=num_worker > 0)
    test_loader = DataLoader(test_dataset,
                             batch_size=batch_size,
                             shuffle=False,  # no need to shuffle the test set
                             num_workers=num_worker,
                             persistent_workers=num_worker > 0)

    main(epochs=epochs,
         model=create_model(),
         train_loader=train_loader,
         test_loader=test_loader)