-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathModel.py
40 lines (27 loc) · 1.08 KB
/
Model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.nn as nn
import os
import glob
class Model(nn.Module):
    """Base ``nn.Module`` with per-model checkpoint save/load helpers.

    Checkpoints are written to ``<path>/<name>/model-<epoch>.pth`` where
    ``<epoch>`` is zero-padded to at least five digits. Subclasses must
    implement :meth:`save_results`.
    """

    # File-name pieces shared by save() and load() so the two stay in sync.
    _CHECKPOINT_PREFIX = "model-"
    _CHECKPOINT_SUFFIX = ".pth"

    def __init__(self, name):
        """Initialise the module.

        Args:
            name: Sub-directory name under which this model's checkpoints
                are written and looked up.
        """
        super(Model, self).__init__()
        self.name = name

    def save(self, path, epoch=0):
        """Save ``self.state_dict()`` to ``<path>/<name>/model-<epoch>.pth``.

        The ``<path>/<name>`` directory is created if it does not exist.

        Args:
            path: Root directory for checkpoints.
            epoch: Number embedded in the file name, zero-padded to 5 digits.

        Returns:
            The full path of the checkpoint file that was written.
        """
        complete_path = os.path.join(path, self.name)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(complete_path, exist_ok=True)
        checkpoint = os.path.join(
            complete_path,
            "{}{}{}".format(
                self._CHECKPOINT_PREFIX,
                str(epoch).zfill(5),
                self._CHECKPOINT_SUFFIX,
            ),
        )
        torch.save(self.state_dict(), checkpoint)
        return checkpoint

    def save_results(self, path, data):
        """Persist model-specific results; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Model subclass must implement this method.")

    def load(self, path, modelfile=None, map_location=None):
        """Load a checkpoint from ``<path>/<name>`` into this module.

        Args:
            path: Root directory previously passed to :meth:`save`.
            modelfile: File name (relative to ``<path>/<name>``) to load. If
                None, the ``model-<epoch>.pth`` file with the highest epoch
                number is used.
            map_location: Forwarded to ``torch.load`` (e.g. ``"cpu"`` to load
                a checkpoint that was saved from GPU tensors). None keeps
                ``torch.load``'s default behaviour.

        Raises:
            IOError: if ``<path>/<name>`` is not a directory, or if
                ``modelfile`` is None and no ``model-<epoch>.pth`` checkpoint
                exists there.
        """
        complete_path = os.path.join(path, self.name)
        if not os.path.isdir(complete_path):
            raise IOError("{} directory does not exist in {}".format(self.name, path))
        if modelfile is None:
            mf = self._latest_checkpoint(complete_path)
        else:
            mf = os.path.join(complete_path, modelfile)
        self.load_state_dict(torch.load(mf, map_location=map_location))

    def _latest_checkpoint(self, directory):
        """Return the path of the highest-epoch ``model-<epoch>.pth`` in *directory*."""
        prefix = self._CHECKPOINT_PREFIX
        suffix = self._CHECKPOINT_SUFFIX
        # Escape the directory so glob metacharacters in it ("[", "*", "?")
        # are matched literally; only the file-name part is a pattern.
        pattern = os.path.join(glob.escape(directory), prefix + "*" + suffix)
        candidates = []
        for filename in glob.glob(pattern):
            stem = os.path.basename(filename)[len(prefix):-len(suffix)]
            try:
                epoch = int(stem)
            except ValueError:
                continue  # matches the glob but has no integer epoch part
            candidates.append((epoch, filename))
        if not candidates:
            raise IOError(
                "no {}<epoch>{} checkpoint found in {}".format(prefix, suffix, directory)
            )
        # Compare epochs numerically: zero-padding to 5 digits only keeps
        # lexicographic order correct for epochs below 100000.
        return max(candidates)[1]