# demo.py
'''
Copyright (c) 2022 by Haiming Zhang. All Rights Reserved.
Author: Haiming Zhang
Date: 2022-04-15 13:27:07
Email: haimingzhang@link.cuhk.edu.cn
Description: demo
'''
import argparse

import pytorch_lightning as pl
from omegaconf import OmegaConf

from dataset import get_3dmm_dataset, get_test_dataset
from models import get_model
from utils.utils import get_git_commit_id


def parse_config():
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='./config/demo.yaml', help='the config file path')
    parser.add_argument('--gpu', type=int, nargs='+', default=(0, 1), help='specify gpu devices')
    parser.add_argument('--checkpoint_dir', type=str, nargs='?', const="work_dir2/debug")
    parser.add_argument('--checkpoint', type=str, default=None, help="the pretrained checkpoint path")
    parser.add_argument('--test_mode', action='store_true', help="whether to run in test mode")
    args = parser.parse_args()

    config = OmegaConf.load(args.cfg)
    if args.checkpoint_dir is None:  # fall back to the YAML value when --checkpoint_dir is not given
        args.checkpoint_dir = config.checkpoint_dir
    config.update(vars(args))  # override the configuration with the command-line values

    try:
        config['commit_id'] = get_git_commit_id()
    except Exception:  # avoid a bare except so KeyboardInterrupt etc. still propagate
        print("[WARNING] Couldn't get the git commit id")

    print(OmegaConf.to_yaml(config, resolve=True))
    return config
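
# Illustration of how --checkpoint_dir resolves (example paths below are
# hypothetical): with argparse's nargs='?', the flag may appear without a
# value, in which case `const` is used; omitting the flag entirely leaves the
# default None, which triggers the YAML fallback inside parse_config().
#
#   python demo.py                                  -> checkpoint_dir taken from demo.yaml
#   python demo.py --checkpoint_dir                 -> "work_dir2/debug"
#   python demo.py --checkpoint_dir work_dir2/exp1  -> "work_dir2/exp1"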
config = parse_config()

# Create the model
model = get_model(config['model_name'], config)

if config.checkpoint is None:
    print("[WARNING] Training from scratch!")
else:
    print(f"[WARNING] Loading pretrained model from {config.checkpoint}")
    model = model.load_from_checkpoint(config.checkpoint, config=config)
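
# Note: load_from_checkpoint is a classmethod in PyTorch Lightning, so the call
# above constructs a new model from the checkpoint rather than mutating `model`
# in place; the reassignment is what makes the loaded weights take effect.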
print(f"{'=' * 25} Start Testing, Good Luck! {'=' * 25}")

# test_dataloader = get_3dmm_dataset(config['dataset'], split="voca_test", shuffle=False)
test_dataloader = get_test_dataset(config['dataset'])
print(f"The testing dataloader length is {len(test_dataloader)}")

trainer = pl.Trainer(gpus=1, default_root_dir=config['checkpoint_dir'])
trainer.test(model, test_dataloader)
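
# A minimal sketch of invoking the demo (the checkpoint path is a placeholder,
# not a file shipped with the repo):
#
#   python demo.py --cfg ./config/demo.yaml \
#       --checkpoint work_dir2/debug/last.ckpt \
#       --checkpoint_dir work_dir2/debug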