@@ -12,6 +12,7 @@ class ModelParams(NamedTuple):
12
12
batch_size : int = 32
13
13
device : str = 'cuda:0' if torch .cuda .is_available () else 'cpu'
14
14
epochs : int = 15
15
+ resume_path : str = None
15
16
16
17
# lstm
17
18
hidden_size : int = 2
@@ -60,8 +61,12 @@ def train(params: ModelParams):
60
61
test_loader = DataLoader (XORDataset (train = False ), batch_size = params .batch_size )
61
62
62
63
step = 0
64
+ epoch = 1
63
65
64
- for epoch in range (1 , params .epochs ):
66
+ if params .resume_path :
67
+ step , epoch = resume_train_state (params .resume_path , model , optimizer )
68
+
69
+ for epoch in range (epoch , params .epochs ):
65
70
for inputs , targets in train_loader :
66
71
inputs = inputs .to (params .device )
67
72
targets = targets .to (params .device )
@@ -84,6 +89,25 @@ def train(params: ModelParams):
84
89
# evaluate per epoch
85
90
evaluate (model , test_loader )
86
91
92
+ save_train_state (step , epoch , model , optimizer )
93
+
94
+
95
def resume_train_state(path, model, optimizer):
    """Restore model and optimizer state from a checkpoint file.

    The checkpoint is expected to be a dict with the keys ``'model'``,
    ``'optimizer'``, ``'step'`` and ``'epoch'``, as written by
    ``save_train_state``.

    Args:
        path: Location of the checkpoint file to read.
        model: Module whose parameters are overwritten in place.
        optimizer: Optimizer whose state is overwritten in place.

    Returns:
        A ``(step, epoch)`` tuple taken from the checkpoint. Because
        ``save_train_state`` stores ``epoch + 1``, the returned epoch is the
        next epoch to run rather than the one that was just finished.
    """
    # Deserialize onto the CPU so a checkpoint written on a GPU machine can be
    # resumed where CUDA is unavailable; load_state_dict then copies the
    # values into the tensors the model/optimizer already own.
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    return state['step'], state['epoch']
100
+
101
+
102
def save_train_state(step, epoch, model, optimizer):
    """Write a resumable checkpoint for the epoch that just finished.

    The checkpoint is saved to ``./data/epoch_<epoch>.pt`` (relative to the
    current working directory) and contains the model and optimizer state
    dicts together with the training counters.

    Args:
        step: Step counter to store in the checkpoint.
        epoch: Index of the epoch that has just completed; used in the file
            name.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
    """
    import os  # local import so this block does not rely on a file-level one

    state = {
        # Store the *next* epoch so a resumed run does not repeat this one.
        'epoch': epoch + 1,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'step': step,
    }
    # torch.save does not create missing parent directories; make sure the
    # target directory exists so a finished epoch is never lost to an IOError.
    os.makedirs('./data', exist_ok=True)
    torch.save(state, f'./data/epoch_{epoch}.pt')
110
+
87
111
88
112
def evaluate (model , loader ):
89
113
is_correct = np .array ([])
0 commit comments