forked from remotebiosensing/rppg
-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
57 lines (50 loc) · 1.89 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torchinfo
import torchsummary
from log import log_warning, log_info
from nets.models.DeepPhys import DeepPhys
from nets.models.DeepPhys_DA import DeepPhys_DA
from nets.models.PhysNet import PhysNet
from nets.models.PhysNet import PhysNet_2DCNN_LSTM
def get_model(model_name: str = "DeepPhys"):
    """
    Build a fresh instance of the network selected by ``model_name``.

    :param model_name: one of "DeepPhys", "DeepPhys_DA", "PhysNet" or "PhysNet_LSTM"
    :return: a new instance of the matching model class, constructed without arguments
    :raises NotImplementedError: when ``model_name`` matches none of the names above
    """
    # (name, constructor) pairs, checked in order with plain equality.
    registry = (
        ("DeepPhys", DeepPhys),
        ("DeepPhys_DA", DeepPhys_DA),
        ("PhysNet", PhysNet),
        ("PhysNet_LSTM", PhysNet_2DCNN_LSTM),
    )
    for known_name, constructor in registry:
        if model_name == known_name:
            return constructor()

    # Nothing matched: warn, then tell the caller where a custom model belongs.
    log_warning("use implemented model")
    raise NotImplementedError("implement a custom model(%s) in /nets/models/" % model_name)
def is_model_support(model_name, model_list):
    """
    Check that ``model_name`` is one of the implemented models.

    :param model_name: model name to look up
    :param model_list: implemented model list
    :return: None when ``model_name`` is contained in ``model_list``
    :raises NotImplementedError: when ``model_name`` is not contained in ``model_list``
    """
    # Supported model: nothing to report.
    if model_name in model_list:
        return

    log_warning("use implemented model")
    raise NotImplementedError("implement a custom model(%s) in /nets/models/" % model_name)
def summary(model, model_name):
    """
    Log a banner with the model name, then print a layer summary of ``model``.

    :param model: torch.nn.module class
    :param model_name: implemented model name ("DeepPhys", "DeepPhys_DA",
        "PhysNet" or "PhysNet_LSTM")
    :return: None
    :raises NotImplementedError: when ``model_name`` is not one of the names above
    """
    log_info("=========================================")
    log_info(model_name)
    log_info("=========================================")
    # Compare against the *string* "DeepPhys_DA": a name string never equals the
    # DeepPhys_DA class object, so that model used to fall through to the error branch.
    if model_name in ("DeepPhys", "DeepPhys_DA"):
        torchsummary.summary(model, (2, 3, 36, 36))
    elif model_name in ("PhysNet", "PhysNet_LSTM"):
        # torchinfo is used here because it supports recursive layers (RNN, LSTM).
        # torchinfo: (model, input_size=(batch_size, input_size))
        torchinfo.summary(model, (1, 3, 32, 128, 128))
    else:
        log_warning("use implemented model")
        raise NotImplementedError("implement a custom model(%s) in /nets/models/" % model_name)