from __future__ import print_function

import os

import torch
import torch.multiprocessing as mp
from torch.multiprocessing import Barrier
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset
from torchvision import datasets, transforms

batch_size = 64        # input batch size for training
test_batch_size = 1000  # input batch size for testing
epochs = 3             # number of global epochs to train
lr = 0.01              # learning rate
momentum = 0.5         # SGD momentum
seed = 1               # random seed
log_interval = 10      # how many batches to wait before logging training status
n_workers = 4          # how many training processes to use
cuda = True            # enables CUDA training
mps = False            # enables macOS GPU training
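# Note: `cuda` and `mps` only express a preference; the device actually used is
# resolved at startup below based on what the host hardware supports.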


class CustomSubset(Subset):
    '''A custom subset class with customizable data transformation'''

    def __init__(self, dataset, indices, subset_transform=None):
        super().__init__(dataset, indices)
        self.subset_transform = subset_transform

    def __getitem__(self, idx):
        x, y = self.dataset[self.indices[idx]]
        if self.subset_transform:
            x = self.subset_transform(x)
        return x, y

    def __len__(self):
        return len(self.indices)
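# Illustrative use of the `subset_transform` hook (not exercised below, where the
# split only partitions indices): a per-worker augmentation could be injected as
#   CustomSubset(train_dataset, range(0, 1000),
#                subset_transform=lambda x: x + 0.01 * torch.randn_like(x))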


def dataset_split(dataset, n_workers):
    n_samples = len(dataset)
    n_samples_per_worker = n_samples // n_workers
    local_datasets = []
    for w_id in range(n_workers):
        if w_id < n_workers - 1:
            local_datasets.append(CustomSubset(
                dataset,
                range(w_id * n_samples_per_worker, (w_id + 1) * n_samples_per_worker)))
        else:
            local_datasets.append(CustomSubset(
                dataset,
                range(w_id * n_samples_per_worker, n_samples)))
    return local_datasets
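# With MNIST's 60,000 training samples and n_workers = 4, each worker receives a
# contiguous shard of 15,000 samples; the last worker also absorbs any remainder
# when the dataset size is not evenly divisible.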


def pull_down(global_W, local_Ws, n_workers):
    # pull down global model to local
    for rank in range(n_workers):
        for name in local_Ws[rank].keys():
            local_Ws[rank][name].data = global_W[name].data
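# pull_down overwrites each local parameter's `.data` in place, so the local
# models restart every round from the current global weights without replacing
# the Parameter objects referenced in local_Ws.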


def aggregate(global_W, local_Ws, n_workers):
    # re-initialize the global model to zero
    for name, value in global_W.items():
        global_W[name].data = torch.zeros_like(value)

    # accumulate the local models ...
    for rank in range(n_workers):
        for name, value in local_Ws[rank].items():
            global_W[name].data += value.data

    # ... and average them
    for name in global_W.keys():
        global_W[name].data /= n_workers
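# aggregate takes an unweighted mean of the worker parameters (a FedAvg-style
# average); since dataset_split hands every worker roughly the same number of
# samples, equal weights are a reasonable choice here.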


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
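# The classic small MNIST CNN from the PyTorch examples: two 5x5 convolutions,
# each followed by 2x2 max-pooling, shrink the 28x28 input to 20 feature maps of
# size 4x4, hence the 320-unit (20 * 4 * 4) input to fc1.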


def train_epoch(epoch, rank, local_model, device, dataset, synchronizer, dataloader_kwargs):
    torch.manual_seed(seed + rank)
    train_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)
    optimizer = optim.SGD(local_model.parameters(), lr=lr, momentum=momentum)

    local_model.train()
    pid = os.getpid()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = local_model(data.to(device))
        loss = F.nll_loss(output, target.to(device))
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                pid, epoch + 1, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    # synchronizer.wait()
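    # Note: if synchronizer.wait() were enabled, each worker would block at the
    # shared Barrier after finishing its local epoch; in the synchronous flow in
    # __main__ the parent already waits on p.join(), so the barrier is unused.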


def test(epoch, model, device, dataset, dataloader_kwargs):
    torch.manual_seed(seed)
    test_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)

    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data.to(device))
            test_loss += F.nll_loss(output, target.to(device), reduction='sum').item()  # sum up batch loss
            pred = output.max(1)[1]  # get the index of the max log-probability
            correct += pred.eq(target.to(device)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest Epoch: {} Global loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        epoch + 1, test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
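# Because the per-batch losses are accumulated with reduction='sum' and divided
# by the dataset size at the end, the reported loss is the average per-sample
# negative log-likelihood and does not depend on the evaluation batch size.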


if __name__ == "__main__":
    use_cuda = cuda and torch.cuda.is_available()
    use_mps = mps and torch.backends.mps.is_available()
    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = datasets.MNIST('./data', train=True, download=True,
                                   transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, download=True,
                                  transform=transform)
    local_train_datasets = dataset_split(train_dataset, n_workers)

    kwargs = {'batch_size': batch_size,
              'shuffle': True}
    # evaluation gets its own loader settings so that test_batch_size is honored
    test_kwargs = {'batch_size': test_batch_size,
                   'shuffle': False}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,  # subprocesses used to load data
                       'pin_memory': True}
        kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    torch.manual_seed(seed)
    mp.set_start_method('spawn', force=True)
    # Very important, otherwise CUDA memory cannot be allocated in the child process

    local_models = [Net().to(device) for _ in range(n_workers)]
    global_model = Net().to(device)
    local_Ws = [{key: value for key, value in local_models[i].named_parameters()} for i in range(n_workers)]
    global_W = {key: value for key, value in global_model.named_parameters()}
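    # named_parameters() yields the live nn.Parameter objects, so local_Ws and
    # global_W are views onto the model weights: mutating their `.data` inside
    # pull_down/aggregate updates the corresponding models directly.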

    synchronizer = Barrier(n_workers)
    for epoch in range(epochs):
        # pull down the global model to every local model
        pull_down(global_W, local_Ws, n_workers)

        # train the local models across `n_workers` processes
        processes = []
        for rank in range(n_workers):
            p = mp.Process(target=train_epoch, args=(epoch, rank, local_models[rank], device,
                                                     local_train_datasets[rank], synchronizer, kwargs))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        # average the local models into the global model
        aggregate(global_W, local_Ws, n_workers)

        # evaluate the aggregated global model after each epoch
        test(epoch, global_model, device, test_dataset, test_kwargs)
    # Test result for synchronous training:  Test Epoch: 3 Global loss: 0.0732, Accuracy: 9796/10000 (98%)
    # Test result for asynchronous training: Test Epoch: 3 Global loss: 0.0742, Accuracy: 9789/10000 (98%)
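    # Run this script directly to launch training; torchvision downloads MNIST
    # into ./data on the first run (download=True above).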