Skip to content

Commit b758e3c

Browse files
b21527616chsasank
authored andcommitted
Fixed deep copying model.state_dict() (pytorch#188)
Otherwise best_model_wts changes always to last epoch's state
1 parent 016d0cc commit b758e3c

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

beginner_source/transfer_learning_tutorial.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import matplotlib.pyplot as plt
4747
import time
4848
import os
49+
import copy
4950

5051
plt.ion() # interactive mode
5152

@@ -144,7 +145,7 @@ def imshow(inp, title=None):
144145
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
145146
since = time.time()
146147

147-
best_model_wts = model.state_dict()
148+
best_model_wts = copy.deepcopy(model.state_dict())
148149
best_acc = 0.0
149150

150151
for epoch in range(num_epochs):
@@ -200,7 +201,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
200201
# deep copy the model
201202
if phase == 'val' and epoch_acc > best_acc:
202203
best_acc = epoch_acc
203-
best_model_wts = model.state_dict()
204+
best_model_wts = copy.deepcopy(model.state_dict())
204205

205206
print()
206207

0 commit comments

Comments
 (0)