@@ -151,6 +151,8 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
     ref_model.train(False)
     ref_model.to(device)
 
+    #print(ref_model)
+
     for param in model_init.Tmodel.classifier.parameters():
         param.requires_grad = True
 
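Note on this hunk: `ref_model.train(False)` in the context above is equivalent to `ref_model.eval()`; it fixes BatchNorm statistics and disables Dropout so the reference model produces deterministic distillation targets, and `nn.Module.to(device)` moves the module's parameters in place. A minimal sketch of that pattern, using a stand-in AlexNet since the reference model's definition is outside this diff:

    import torch
    import torchvision.models as models

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ref_model = models.alexnet(num_classes=10)  # stand-in for the frozen reference copy
    ref_model.train(False)                      # same as ref_model.eval()
    ref_model.to(device)                        # nn.Module.to() modifies the module in place

    with torch.no_grad():                       # target logits need no gradient history
        ref_logits = ref_model(torch.randn(4, 3, 224, 224, device=device))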
@@ -176,14 +178,18 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
     #Actually makes the changes to the model_init, so slightly redundant
     print("Initializing the model to be trained")
     model_init = initialize_new_model(model_init, num_classes, num_of_classes_old)
-    model_init.to(device)
+    #print(model_init)
+    #model_init.to(device)
     start_epoch = 0
 
     #The training process format or LwF (Learning without Forgetting)
     # Add the start epoch code
 
     if (best_relatedness > 0.85):
 
+        model_init.to(device)
+        ref_model.to(device)
+
         print("Using the LwF approach")
         for epoch in range(start_epoch, num_epochs):
             since = time.time()
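Note on this hunk: the two `.to(device)` calls are hoisted from the per-batch loop (removed in a later hunk) to just before the LwF epoch loop. Because `nn.Module.to()` works in place, one call before training suffices and repeating it per batch is pure overhead; input tensors, by contrast, are copied to the device and must be moved (and reassigned) per batch. The resulting structure, sketched with placeholder names:

    import torch
    import torch.nn as nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.Linear(8, 2)

    model.to(device)                          # once, before the loop
    for _ in range(3):
        batch = torch.randn(4, 8).to(device)  # tensor .to() is not in-place: reassign per batch
        out = model(batch)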
@@ -197,7 +203,7 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
 
             #scales the optimizer every 10 epochs
             optimizer = exp_lr_scheduler(optimizer, epoch, lr)
-            model_init = model_init.train(True)
+            # model_init = model_init.train(True)
 
             for data in dset_loaders:
                 input_data, labels = data
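Note on this hunk: commenting out `model_init.train(True)` removes a redundant assignment (`.train()` returns `self`), but it also means dropout/batch-norm layers are no longer explicitly switched back to training mode each epoch. `exp_lr_scheduler` is the repo's helper and its body is outside this diff; the comment above it suggests the classic step-decay pattern, sketched here as a hypothetical implementation:

    import torch
    import torch.nn as nn

    def exp_lr_scheduler(optimizer, epoch, init_lr, decay_factor=0.1, lr_decay_epoch=10):
        # assumed behaviour: multiply init_lr by decay_factor once per lr_decay_epoch epochs
        lr = init_lr * (decay_factor ** (epoch // lr_decay_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return optimizer

    optimizer = torch.optim.SGD(nn.Linear(8, 2).parameters(), lr=0.1)
    optimizer = exp_lr_scheduler(optimizer, epoch=10, init_lr=0.1)  # lr becomes 0.01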
@@ -212,32 +218,27 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
                 input_data = Variable(input_data)
                 labels = Variable(labels)
 
-                model_init.to(device)
-                ref_model.to(device)
-
                 output = model_init(input_data)
                 ref_output = ref_model(input_data)
-
                 del input_data
 
                 optimizer.zero_grad()
-                model_init.zero_grad()
 
                 # loss_1 only takes in the outputs from the nodes of the old classes
 
                 loss1_output = output[:, :num_of_classes_old]
                 loss2_output = output[:, num_of_classes_old:]
 
+                print()
+
                 del output
 
                 loss_1 = model_criterion(loss1_output, ref_output, flag="Distill")
-
                 del ref_output
 
                 # loss_2 takes in the outputs from the nodes that were initialized for the new task
 
                 loss_2 = model_criterion(loss2_output, labels, flag="CE")
-
                 del labels
                 #del output
 
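Note on this hunk: besides dropping the per-batch device transfers and the redundant `model_init.zero_grad()` (already covered by `optimizer.zero_grad()` when the optimizer holds all of the model's parameters), this is the core LwF step: the first `num_of_classes_old` logit columns are distilled against the reference model's outputs, and the remaining columns get plain cross-entropy. `model_criterion` is the repo's function and isn't shown here; a hypothetical sketch of what such a flag dispatch typically computes (temperature-softened distillation in the style of Hinton et al.):

    import torch
    import torch.nn.functional as F

    def model_criterion_sketch(preds, targets, flag, T=2.0):
        # hypothetical stand-in for the repo's model_criterion dispatch
        if flag == "Distill":
            log_p = F.log_softmax(preds / T, dim=1)  # student's softened distribution
            q = F.softmax(targets / T, dim=1)        # teacher's softened distribution
            return F.kl_div(log_p, q, reduction="batchmean") * (T * T)
        elif flag == "CE":
            return F.cross_entropy(preds, targets)
        raise ValueError(flag)

    output = torch.randn(4, 15, requires_grad=True)  # 10 old + 5 new class logits
    ref_output = torch.randn(4, 10)                  # reference model's old-class logits
    labels = torch.randint(0, 5, (4,))               # labels indexed within the new classes
    loss = (model_criterion_sketch(output[:, :10], ref_output, "Distill")
            + model_criterion_sketch(output[:, 10:], labels, "CE"))
    loss.backward()

The added `print()` before `del output` emits only a blank line per batch and looks like leftover debugging.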
@@ -257,7 +258,7 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
 
             print('Epoch Loss:{}'.format(epoch_loss))
 
-            if (epoch != 0 and epoch != num_of_epochs - 1 and (epoch + 1) % 10 == 0):
+            if (epoch != 0 and epoch != num_epochs - 1 and (epoch + 1) % 10 == 0):
                 epoch_file_name = os.path.join(mypath, str(epoch + 1)+'.pth.tar')
                 torch.save({
                     'epoch': epoch,
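Note on this hunk: the rename fixes a NameError, since the loop header defines `num_epochs` and `num_of_epochs` is never bound; the same fix reappears in the finetuning branch below. The guard saves a checkpoint every 10th epoch while skipping epoch 0 and the final epoch. A save/resume round trip matching the visible `'epoch'` key (the other keys are assumed, as the dict is truncated in this hunk):

    import os
    import torch
    import torch.nn as nn

    model = nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    mypath, epoch, num_epochs = ".", 9, 20  # stand-ins for the repo's variables

    if epoch != 0 and epoch != num_epochs - 1 and (epoch + 1) % 10 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),          # assumed key
            'optimizer_state_dict': optimizer.state_dict(),  # assumed key
        }, os.path.join(mypath, str(epoch + 1) + '.pth.tar'))

    ckpt = torch.load(os.path.join(mypath, '10.pth.tar'))
    start_epoch = ckpt['epoch'] + 1  # resume from the epoch after the saved one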
@@ -277,6 +278,7 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
     #Process for finetuning the model
     else:
 
+        model_init.to(device)
         print("Using the finetuning approach")
 
         for epoch in range(start_epoch, num_epochs):
@@ -302,9 +304,6 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
                 input_data = Variable(input_data)
                 labels = Variable(labels)
 
-                #Shifts the model to the device
-                model_init.to(device)
-
                 output = model_init(input_data)
 
                 del input_data
@@ -314,7 +313,7 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
                 model_init.zero_grad()
 
                 #Implemented as explained in the doc string
-                loss = model_criterion(output[num_of_classes_old:], labels)
+                loss = model_criterion(output[num_of_classes_old:], labels, flag='CE')
 
                 del output
                 del labels
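Note on this hunk: passing `flag='CE'` makes the finetuning branch name its loss explicitly, matching the `flag="CE"` call in the LwF branch. One thing the fix does not touch: `output[num_of_classes_old:]` slices the batch dimension, not the class logits; the LwF branch uses the column slice `output[:, num_of_classes_old:]`, which is likely what is intended here too. The difference in one sketch:

    import torch

    num_of_classes_old = 10
    output = torch.randn(4, 15)                   # batch of 4, 10 old + 5 new logits

    print(output[num_of_classes_old:].shape)      # torch.Size([0, 15]) -- empty batch slice
    print(output[:, num_of_classes_old:].shape)   # torch.Size([4, 5])  -- the new-class logits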
@@ -330,7 +329,7 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
 
             print('Epoch Loss:{}'.format(epoch_loss))
 
-            if (epoch != 0 and (epoch + 1) % 5 == 0 and epoch != num_of_epochs - 1):
+            if (epoch != 0 and (epoch + 1) % 5 == 0 and epoch != num_epochs - 1):
                 epoch_file_name = os.path.join(path_to_model, str(epoch + 1)+'.pth.tar')
                 torch.save({
                     'epoch': epoch,