
Commit ed9ac95

add ssgd_pytorch
1 parent 28daba9 commit ed9ac95

File tree

1 file changed: +194 -0 lines changed


optimization/ssgd_pytorch.py (+194 lines)

@@ -0,0 +1,194 @@
from __future__ import print_function
import os

import torch
import torch.multiprocessing as mp
from torch.multiprocessing import Barrier
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset
from torchvision import datasets, transforms


batch_size = 64         # input batch size for training
test_batch_size = 1000  # input batch size for testing
epochs = 3              # number of global epochs to train
lr = 0.01               # learning rate
momentum = 0.5          # SGD momentum
seed = 1                # random seed
log_interval = 10       # how many batches to wait before logging training status
n_workers = 4           # how many training processes to use
cuda = True             # enables CUDA training
mps = False             # enables macOS GPU training


class CustomSubset(Subset):
    '''A custom subset class with customizable data transformation'''

    def __init__(self, dataset, indices, subset_transform=None):
        super().__init__(dataset, indices)
        self.subset_transform = subset_transform

    def __getitem__(self, idx):
        x, y = self.dataset[self.indices[idx]]
        if self.subset_transform:
            x = self.subset_transform(x)
        return x, y

    def __len__(self):
        return len(self.indices)


def dataset_split(dataset, n_workers):
    # evenly partition the dataset across workers; the last worker takes the remainder
    n_samples = len(dataset)
    n_sample_per_workers = n_samples // n_workers
    local_datasets = []
    for w_id in range(n_workers):
        if w_id < n_workers - 1:
            local_datasets.append(CustomSubset(dataset, range(w_id * n_sample_per_workers, (w_id + 1) * n_sample_per_workers)))
        else:
            local_datasets.append(CustomSubset(dataset, range(w_id * n_sample_per_workers, n_samples)))
    return local_datasets


def pull_down(global_W, local_Ws, n_workers):
    # pull down the global model parameters to every local model
    for rank in range(n_workers):
        for name, value in local_Ws[rank].items():
            local_Ws[rank][name].data = global_W[name].data


def aggregate(global_W, local_Ws, n_workers):
    # reset the global model, then average the local models into it
    for name, value in global_W.items():
        global_W[name].data = torch.zeros_like(value)

    for rank in range(n_workers):
        for name, value in local_Ws[rank].items():
            global_W[name].data += value.data

    for name in global_W.keys():
        global_W[name].data /= n_workers


class Net(nn.Module):
    '''A small convolutional network for MNIST classification'''

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train_epoch(epoch, rank, local_model, device, dataset, synchronizer, dataloader_kwargs):
    # one local training epoch, run inside a worker process
    torch.manual_seed(seed + rank)
    train_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)
    optimizer = optim.SGD(local_model.parameters(), lr=lr, momentum=momentum)

    local_model.train()
    pid = os.getpid()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = local_model(data.to(device))
        loss = F.nll_loss(output, target.to(device))
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                pid, epoch + 1, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    # synchronizer.wait()


def test(epoch, model, device, dataset, dataloader_kwargs):
    torch.manual_seed(seed)
    test_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)

    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data.to(device))
            test_loss += F.nll_loss(output, target.to(device), reduction='sum').item()  # sum up batch loss
            pred = output.max(1)[1]  # get the index of the max log-probability
            correct += pred.eq(target.to(device)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest Epoch: {} Global loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        epoch + 1, test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


if __name__ == "__main__":
    use_cuda = cuda and torch.cuda.is_available()
    use_mps = mps and torch.backends.mps.is_available()
    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = datasets.MNIST('./data', train=True, download=True,
                                   transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, download=True,
                                  transform=transform)
    local_train_datasets = dataset_split(train_dataset, n_workers)

    kwargs = {'batch_size': batch_size,
              'shuffle': True}
    if use_cuda:
        kwargs.update({'num_workers': 1,  # num_workers to load data
                       'pin_memory': True,
                       })

    torch.manual_seed(seed)
    mp.set_start_method('spawn', force=True)
    # Very important, otherwise CUDA memory cannot be allocated in the child process

    local_models = [Net().to(device) for _ in range(n_workers)]
    global_model = Net().to(device)
    local_Ws = [{key: value for key, value in local_models[i].named_parameters()} for i in range(n_workers)]
    global_W = {key: value for key, value in global_model.named_parameters()}

    synchronizer = Barrier(n_workers)
    for epoch in range(epochs):
        # pull down the global model to every local model before the epoch starts
        pull_down(global_W, local_Ws, n_workers)

        processes = []
        for rank in range(n_workers):
            # train the local models across `n_workers` processes
            p = mp.Process(target=train_epoch, args=(epoch, rank, local_models[rank], device,
                                                     local_train_datasets[rank], synchronizer, kwargs))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        # average the local models into the global model
        aggregate(global_W, local_Ws, n_workers)

        # test the global model each epoch
        test(epoch, global_model, device, test_dataset, kwargs)

# Test result for synchronous training:  Test Epoch: 3 Global loss: 0.0732, Accuracy: 9796/10000 (98%)
# Test result for asynchronous training: Test Epoch: 3 Global loss: 0.0742, Accuracy: 9789/10000 (98%)
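A minimal way to try the new script, assuming PyTorch and torchvision are installed and the command is run from the repository root (MNIST is downloaded to ./data on first run):

    python optimization/ssgd_pytorch.py

Each global epoch pulls the global weights down to the n_workers local models, trains them in separate processes, averages them back into the global model, and prints the global test accuracy.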
