-
Notifications
You must be signed in to change notification settings - Fork 10
/
run_test.py
38 lines (31 loc) · 1.24 KB
/
run_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch, cv2, os
import numpy as np
import arg, models
from test import tester
# Echo the raw command line so the run's output records how it was invoked.
print(os.sys.argv)

# Model
# Fetch the factory registered under the configured model name, build the
# network from the configured keyword arguments, then wrap it in
# DataParallel and move it to CUDA.
network_factory = models.get_model(arg.model_name)
network = network_factory(**arg.model_args)
model = torch.nn.DataParallel(network).cuda()

# argv[3] must parse as an integer; it selects the checkpoint file
# `<arg.ckpt_path>/<n>.pth`.
ckpt_file = f'{arg.ckpt_path}/{int(os.sys.argv[3])}.pth'
ckpt = torch.load(ckpt_file)
# The state dict is loaded into the wrapped network (`model.module`), not
# into the DataParallel wrapper itself.
model.module.load_state_dict(ckpt['model_state_dict'])

# Light
# NOTE(review): dataset_color is loaded here but is not referenced again in
# the visible part of this script (the tester call passes
# dataset_color=None) -- confirm whether this load is still needed.
dataset_color = np.load(f'{arg.base_path}/data/dataset_color.npy')
# Target size of the light map; it is reversed before being handed to
# cv2.resize further down.
light_size = (8, 16)
##### Test #####
# NOTE(review): the indentation of this section was lost in the copy under
# review. Everything below is placed inside the `if` because `lights` needs
# `light`, which is only assigned on the validate path -- confirm against
# the upstream file.
if os.sys.argv[2].startswith('validate'):
    data_name = os.sys.argv[2]
    test_root = f'{arg.base_path}/data_test'

    # Load the target HDR light map for this data item.
    light_path = f'{test_root}/{data_name}/source_image_masked/target_light.hdr'
    light = cv2.imread(light_path, cv2.IMREAD_UNCHANGED)
    if light is None:
        # cv2.imread signals a missing or unreadable file by returning None
        # instead of raising; fail here with the offending path rather than
        # with an opaque error inside cv2.cvtColor.
        raise FileNotFoundError(f'cannot read target light map: {light_path}')

    # Convert OpenCV's BGR channel order to RGB and shrink the map.
    # cv2.resize takes (width, height), hence the reversed light_size; the
    # result has light_size[0] rows and light_size[1] columns.
    light = cv2.resize(
        cv2.cvtColor(light, cv2.COLOR_BGR2RGB),
        tuple(light_size[::-1]),
        interpolation=cv2.INTER_AREA,
    )

    # 48 frames in three 16-frame segments:
    #   1. second camera component steps by 22.5 per frame, light entry 0;
    #   2. camera fixed at (15, 0), light entries count down from 15 to 0;
    #   3. camera fixed at (15, 0), the HDR map shifted by 0..15 columns.
    # NOTE(review): presumably the camera pairs are (elevation, azimuth) in
    # degrees and the integer light entries index preset lights -- confirm
    # against `tester`.
    cam_poses = [(15, 22.5 * i) for i in range(16)] + [(15, 0)] * 32
    lights = (
        [0] * 16
        + [15 - i for i in range(16)]
        + [np.roll(light, i, axis=1) for i in range(16)]
    )
    filenames = [str(i) for i in range(48)]

    tester(
        data_name=data_name,
        image_ids=[0, 1, 2, 3, 4],
        dataset_color=None,
        extra_color_scale=1.0,
        cam_poses=cam_poses,
        lights=lights,
        filenames=filenames,
        model=model,
        test_path=test_root,
        folder='masked',
    )