@@ -39,20 +39,17 @@ def train_epoch(model, loader, optimizer, epoch, n_epochs, print_freq=1):
39
39
for batch_idx , (input , target ) in enumerate (loader ):
40
40
# Create variables
41
41
if torch .cuda .is_available ():
42
- input_var = torch .autograd .Variable (input .cuda (async = True ))
43
- target_var = torch .autograd .Variable (target .cuda (async = True ))
44
- else :
45
- input_var = torch .autograd .Variable (input )
46
- target_var = torch .autograd .Variable (target )
42
+ input = input .cuda ()
43
+ target = target .cuda ()
47
44
48
45
# compute output
49
- output = model (input_var )
50
- loss = torch .nn .functional .cross_entropy (output , target_var )
46
+ output = model (input )
47
+ loss = torch .nn .functional .cross_entropy (output , target )
51
48
52
49
# measure accuracy and record loss
53
50
batch_size = target .size (0 )
54
51
_ , pred = output .data .cpu ().topk (1 , dim = 1 )
55
- error .update (torch .ne (pred .squeeze (), target .cpu ()).float ().sum () / batch_size , batch_size )
52
+ error .update (torch .ne (pred .squeeze (), target .cpu ()).float ().sum (). item () / batch_size , batch_size )
56
53
losses .update (loss .item (), batch_size )
57
54
58
55
# compute gradient and do SGD step
@@ -88,70 +85,57 @@ def test_epoch(model, loader, print_freq=1, is_test=True):
88
85
model .eval ()
89
86
90
87
end = time .time ()
91
- for batch_idx , (input , target ) in enumerate (loader ):
92
- # Create variables
93
- if torch .cuda .is_available ():
94
- input_var = torch .autograd .Variable (input .cuda (async = True ), volatile = True )
95
- target_var = torch .autograd .Variable (target .cuda (async = True ), volatile = True )
96
- else :
97
- input_var = torch .autograd .Variable (input , volatile = True )
98
- target_var = torch .autograd .Variable (target , volatile = True )
99
-
100
- # compute output
101
- output = model (input_var )
102
- loss = torch .nn .functional .cross_entropy (output , target_var )
103
-
104
- # measure accuracy and record loss
105
- batch_size = target .size (0 )
106
- _ , pred = output .data .cpu ().topk (1 , dim = 1 )
107
- error .update (torch .ne (pred .squeeze (), target .cpu ()).float ().sum () / batch_size , batch_size )
108
- losses .update (loss .data [0 ], batch_size )
109
-
110
- # measure elapsed time
111
- batch_time .update (time .time () - end )
112
- end = time .time ()
113
-
114
- # print stats
115
- if batch_idx % print_freq == 0 :
116
- res = '\t ' .join ([
117
- 'Test' if is_test else 'Valid' ,
118
- 'Iter: [%d/%d]' % (batch_idx + 1 , len (loader )),
119
- 'Time %.3f (%.3f)' % (batch_time .val , batch_time .avg ),
120
- 'Loss %.4f (%.4f)' % (losses .val , losses .avg ),
121
- 'Error %.4f (%.4f)' % (error .val , error .avg ),
122
- ])
123
- print (res )
88
+ with torch .no_grad ():
89
+ for batch_idx , (input , target ) in enumerate (loader ):
90
+ # Create variables
91
+ if torch .cuda .is_available ():
92
+ input = input .cuda ()
93
+ target = target .cuda ()
94
+
95
+ # compute output
96
+ output = model (input )
97
+ loss = torch .nn .functional .cross_entropy (output , target )
98
+
99
+ # measure accuracy and record loss
100
+ batch_size = target .size (0 )
101
+ _ , pred = output .data .cpu ().topk (1 , dim = 1 )
102
+ error .update (torch .ne (pred .squeeze (), target .cpu ()).float ().sum ().item () / batch_size , batch_size )
103
+ losses .update (loss .item (), batch_size )
104
+
105
+ # measure elapsed time
106
+ batch_time .update (time .time () - end )
107
+ end = time .time ()
108
+
109
+ # print stats
110
+ if batch_idx % print_freq == 0 :
111
+ res = '\t ' .join ([
112
+ 'Test' if is_test else 'Valid' ,
113
+ 'Iter: [%d/%d]' % (batch_idx + 1 , len (loader )),
114
+ 'Time %.3f (%.3f)' % (batch_time .val , batch_time .avg ),
115
+ 'Loss %.4f (%.4f)' % (losses .val , losses .avg ),
116
+ 'Error %.4f (%.4f)' % (error .val , error .avg ),
117
+ ])
118
+ print (res )
124
119
125
120
# Return summary statistics
126
121
return batch_time .avg , losses .avg , error .avg
127
122
128
123
129
- def train (model , train_set , test_set , save , n_epochs = 300 , valid_size = 5000 ,
124
+ def train (model , train_set , valid_set , test_set , save , n_epochs = 300 ,
130
125
batch_size = 64 , lr = 0.1 , wd = 0.0001 , momentum = 0.9 , seed = None ):
131
126
if seed is not None :
132
127
torch .manual_seed (seed )
133
128
134
- # Create train/valid split
135
- if valid_size :
136
- indices = torch .randperm (len (train_set ))
137
- train_indices = indices [:len (indices ) - valid_size ]
138
- train_sampler = torch .utils .data .sampler .SubsetRandomSampler (train_indices )
139
- valid_indices = indices [len (indices ) - valid_size :]
140
- valid_sampler = torch .utils .data .sampler .SubsetRandomSampler (valid_indices )
141
-
142
129
# Data loaders
130
+ train_loader = torch .utils .data .DataLoader (train_set , batch_size = batch_size , shuffle = True ,
131
+ pin_memory = (torch .cuda .is_available ()), num_workers = 0 )
143
132
test_loader = torch .utils .data .DataLoader (test_set , batch_size = batch_size , shuffle = False ,
144
133
pin_memory = (torch .cuda .is_available ()), num_workers = 0 )
145
- if valid_size :
146
- train_loader = torch .utils .data .DataLoader (train_set , batch_size = batch_size , sampler = train_sampler ,
147
- pin_memory = (torch .cuda .is_available ()), num_workers = 0 )
148
- valid_loader = torch .utils .data .DataLoader (train_set , batch_size = batch_size , sampler = valid_sampler ,
149
- pin_memory = (torch .cuda .is_available ()), num_workers = 0 )
134
+ if valid_set is None :
135
+ valid_loader = None
150
136
else :
151
- train_loader = torch .utils .data .DataLoader (train_set , batch_size = batch_size , shuffle = True ,
137
+ valid_loader = torch .utils .data .DataLoader (valid_set , batch_size = batch_size , shuffle = False ,
152
138
pin_memory = (torch .cuda .is_available ()), num_workers = 0 )
153
- valid_loader = None
154
-
155
139
# Model on cuda
156
140
if torch .cuda .is_available ():
157
141
model = model .cuda ()
@@ -264,6 +248,16 @@ def demo(data, save, depth=100, growth_rate=12, efficient=True, valid_size=5000,
264
248
train_set = datasets .CIFAR10 (data , train = True , transform = train_transforms , download = True )
265
249
test_set = datasets .CIFAR10 (data , train = False , transform = test_transforms , download = False )
266
250
251
+ if valid_size :
252
+ valid_set = datasets .CIFAR10 (data , train = True , transform = test_transforms )
253
+ indices = torch .randperm (len (train_set ))
254
+ train_indices = indices [:len (indices ) - valid_size ]
255
+ valid_indices = indices [len (indices ) - valid_size :]
256
+ train_set = torch .utils .data .Subset (train_set , train_indices )
257
+ valid_set = torch .utils .data .Subset (valid_set , valid_indices )
258
+ else :
259
+ valid_set = None
260
+
267
261
# Models
268
262
model = DenseNet (
269
263
growth_rate = growth_rate ,
@@ -281,8 +275,8 @@ def demo(data, save, depth=100, growth_rate=12, efficient=True, valid_size=5000,
281
275
raise Exception ('%s is not a dir' % save )
282
276
283
277
# Train the model
284
- train (model = model , train_set = train_set , test_set = test_set , save = save ,
285
- valid_size = valid_size , n_epochs = n_epochs , batch_size = batch_size , seed = seed )
278
+ train (model = model , train_set = train_set , valid_set = valid_set , test_set = test_set , save = save ,
279
+ n_epochs = n_epochs , batch_size = batch_size , seed = seed )
286
280
print ('Done!' )
287
281
288
282
0 commit comments