Skip to content

Commit 855cccc

Browse files
committed
add ability to resume training
1 parent 44aa961 commit 855cccc

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

train.py

Lines changed: 25 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ class ModelParams(NamedTuple):
1212
batch_size: int = 32
1313
device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
1414
epochs: int = 15
15+
resume_path: str = None
1516

1617
# lstm
1718
hidden_size: int = 2
@@ -60,8 +61,12 @@ def train(params: ModelParams):
6061
test_loader = DataLoader(XORDataset(train=False), batch_size=params.batch_size)
6162

6263
step = 0
64+
epoch = 1
6365

64-
for epoch in range(1, params.epochs):
66+
if params.resume_path:
67+
step, epoch = resume_train_state(params.resume_path, model, optimizer)
68+
69+
for epoch in range(epoch, params.epochs):
6570
for inputs, targets in train_loader:
6671
inputs = inputs.to(params.device)
6772
targets = targets.to(params.device)
@@ -84,6 +89,25 @@ def train(params: ModelParams):
8489
# evaluate per epoch
8590
evaluate(model, test_loader)
8691

92+
save_train_state(step, epoch, model, optimizer)
93+
94+
95+
def resume_train_state(path, model, optimizer, map_location='cpu'):
    """Restore model and optimizer state from a training checkpoint.

    The checkpoint is expected to be a dict with the keys ``'model'``,
    ``'optimizer'``, ``'step'`` and ``'epoch'`` (the layout written by
    ``save_train_state``).

    Args:
        path: Location of the checkpoint file to read.
        model: Module whose parameters are overwritten in place via
            ``load_state_dict``.
        optimizer: Optimizer whose state is overwritten in place via
            ``load_state_dict``.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so
            that a checkpoint written on a GPU can still be read on a
            machine without CUDA.

    Returns:
        A ``(step, epoch)`` tuple taken from the checkpoint.

    Raises:
        KeyError: If the checkpoint lacks one of the expected keys.
    """
    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects -- only resume from checkpoints you trust.
    # Loading onto the CPU first is safe: load_state_dict copies the values
    # into the tensors the model already owns, on whatever device they are.
    state = torch.load(path, map_location=map_location)
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    return state['step'], state['epoch']
100+
101+
102+
def save_train_state(step, epoch, model, optimizer):
    """Write a resumable checkpoint for the epoch that just finished.

    The file is saved as ``./data/epoch_{epoch}.pt`` and holds a dict with
    the model and optimizer state dicts, the current ``step`` and the epoch
    to resume from.

    Args:
        step: Step counter to store in the checkpoint.
        epoch: Number of the epoch that was just completed; used in the
            file name.
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
    """
    # Local import: the file's top-level import block is not visible here.
    import os

    state = {
        # Stored as the NEXT epoch so a resumed run does not repeat this one.
        'epoch': epoch + 1,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'step': step,
    }
    # torch.save does not create missing parent directories.
    os.makedirs('./data', exist_ok=True)
    torch.save(state, f'./data/epoch_{epoch}.pt')
110+
87111

88112
def evaluate(model, loader):
89113
is_correct = np.array([])

0 commit comments

Comments
 (0)