-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathinfer.py
67 lines (55 loc) · 1.94 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torchvision
from torchvision import transforms
from glob import glob
import os
from PIL import Image
import time
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[1]))
from utils.util import *
# Identifier of this model; appears in the output directory and the timing report.
model_name = 'DSLR'
# Input location: a single image file, or a directory that is searched recursively.
in_path = './input/test.png'
# Output images are saved beneath this directory.
out_path = './output/{}'.format(model_name)
def load_model(weight_path='weights/DSLR/weight.pth'):
    """Build the network, load its checkpoint, and return it ready for inference.

    Args:
        weight_path: Checkpoint passed to ``Model.load_weight``. Defaults to the
            original hard-coded location. It is a relative path, presumably
            resolved against the current working directory; confirm against
            ``Model.load_weight``.

    Returns:
        The model after ``.eval().cuda()``. A CUDA device is therefore required.
    """
    # Project-local module, imported lazily inside the function as before.
    from model import Model

    net = Model()
    net.load_weight(weight_path)
    # Switch to evaluation mode and move to the default CUDA device.
    return net.eval().cuda()
def load_data_paths():
    """Collect the image files to process from the module-level ``in_path``.

    If ``in_path`` is a file, it is the only input. ``in_path`` is then rebound
    to that file's parent directory (``'.'`` for a bare filename), so that
    ``inference`` can map inputs to output locations relative to it.

    If ``in_path`` is a directory, it is walked recursively. Every file whose
    name ends with ``.jpg``, ``.png``, ``.jpeg`` or ``.bmp`` (case-insensitive)
    is collected. Directories and files are visited in sorted order, so the
    result is reproducible across filesystems.

    Returns:
        A list of input image paths.

    Raises:
        FileNotFoundError: If ``in_path`` is neither an existing file nor an
            existing directory.
    """
    global in_path
    if os.path.isfile(in_path):
        input_paths = [in_path]
        # dirname('x.png') is ''; fall back to '.' so the root stays a usable path.
        in_path = os.path.dirname(in_path) or os.curdir
        return input_paths

    if os.path.isdir(in_path):
        extensions = ('.jpg', '.png', '.jpeg', '.bmp')
        input_paths = []
        for root, dirs, files in os.walk(in_path):
            # Sorting in place makes os.walk descend in a deterministic order.
            dirs.sort()
            input_paths.extend(
                os.path.join(root, name)
                for name in sorted(files)
                if name.lower().endswith(extensions)
            )
        return input_paths

    raise FileNotFoundError('Input path does not exist: {}'.format(in_path))
def inference(model, input_paths):
    """Run ``model`` on every image in ``input_paths`` and save the outputs.

    Each output is written to the module-level ``out_path`` joined with the
    input's path relative to the module-level ``in_path``. Directory inputs
    therefore keep their sub-folder layout.

    Each image is padded with ``padding(img, 256)`` before the forward pass and
    cropped back with ``unpadding(output, h, w)`` afterwards. These helpers
    presumably come from the ``utils.util`` star import; confirm.

    Only the forward pass is timed. The total time and the per-image average
    are printed at the end.

    Args:
        model: Callable network whose parameters are on the GPU
            (as returned by ``load_model``).
        input_paths: Sized collection of image paths to process.
    """
    to_tensor = transforms.ToTensor()
    # relpath needs a non-empty start; an empty root means the current directory.
    src_root = in_path or os.curdir
    total_time = 0.0

    with torch.no_grad():
        for input_path in input_paths:
            # Mirror the input tree under out_path. Using str.replace instead
            # would also rewrite later occurrences of in_path inside the path,
            # and would drop the separator when in_path ends with '/'.
            output_path = os.path.join(out_path, os.path.relpath(input_path, src_root))
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            # Close the file promptly. Forcing 3 channels keeps RGBA, grayscale
            # or palette images from changing the tensor's channel count.
            # NOTE(review): assumes the model expects RGB input; confirm.
            with Image.open(input_path) as pil_img:
                img = to_tensor(pil_img.convert('RGB')).unsqueeze(0).cuda()
            img, h, w = padding(img, 256)

            # CUDA work is queued asynchronously. Synchronize on both sides of
            # the forward pass so the timer measures execution, not just launch.
            torch.cuda.synchronize()
            tic = time.perf_counter()
            output = model(img)
            torch.cuda.synchronize()
            total_time += time.perf_counter() - tic

            output = unpadding(output, h, w)
            torchvision.utils.save_image(output, output_path)

    # Guard the average against an empty input list (previously ZeroDivisionError).
    per_image = total_time / len(input_paths) if input_paths else 0.0
    print('{} Total time: {:.4f}s Speed: {:.4f}s/img'.format(model_name, total_time, per_image))
if __name__ == '__main__':
    # Script entry point: load the network, collect the input images, run them.
    # Python evaluates call arguments left to right, so load_model() still runs
    # before load_data_paths().
    inference(load_model(), load_data_paths())