-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
87 lines (73 loc) · 3.24 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import paddle
from math import ceil
def load_pretrained_model(model, pretrained_model):
    """Load generator weights from a checkpoint file into ``model``.

    Only parameters whose name exists in the checkpoint's ``'generator'``
    state dict AND whose shape matches the model's parameter are copied;
    everything else is reported and skipped.

    Args:
        model: Network whose parameters are (partially) overwritten in place
            via ``model.set_dict``.
        pretrained_model: Path to a ``paddle.load``-able checkpoint holding a
            ``'generator'`` entry, or ``None`` to train from scratch.

    Raises:
        ValueError: If ``pretrained_model`` is not ``None`` but the path does
            not exist.
    """
    # Guard clause: nothing to load, just announce training from scratch.
    if pretrained_model is None:
        print(
            'No pretrained model to load, {} will be trained from scratch.'.
            format(model.__class__.__name__))
        return

    print('Loading pretrained model from {}'.format(pretrained_model))
    if not os.path.exists(pretrained_model):
        raise ValueError(
            'The pretrained model directory is not Found: {}'.format(
                pretrained_model))

    checkpoint_state = paddle.load(pretrained_model)
    # The checkpoint stores several sub-dicts; only the generator is wanted.
    checkpoint_state = checkpoint_state['generator']
    current_state = model.state_dict()
    num_params_loaded = 0
    for name in current_state.keys():
        if name not in checkpoint_state:
            print("{} is not in pretrained model".format(name))
            continue
        # Compare shapes as plain lists so tensor-shape objects of
        # different types still compare by value.
        if list(checkpoint_state[name].shape) != list(
                current_state[name].shape):
            print(
                "[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
                .format(name, checkpoint_state[name].shape,
                        current_state[name].shape))
            continue
        current_state[name] = checkpoint_state[name]
        num_params_loaded += 1
    model.set_dict(current_state)
    print("There are {}/{} variables loaded into {}.".format(
        num_params_loaded, len(current_state), model.__class__.__name__))
def chop_forward(model, inp, shave=8, min_size=800000):
    """Run ``model`` over ``inp`` by chopping it into four overlapping
    quadrants, recursing until each patch area is below ``min_size``.

    This lets a large image be processed under a fixed memory budget: each
    quadrant is forwarded independently (under ``paddle.no_grad``) and the
    results are stitched back together, discarding the overlapped borders.

    Args:
        model: Callable whose result unpacks as ``(output, _, _)``; only the
            first element is used, and it is assumed to have the same spatial
            size as the patch fed in — TODO confirm against the model.
        inp: Input tensor of shape ``(N, C, H, W)``.
        shave (int): Overlap in pixels between neighbouring patches, so seams
            can be cut away when stitching.
        min_size (int): Maximum patch area (``h_size * w_size``) forwarded
            directly; larger patches are chopped recursively.

    Returns:
        Tensor with the same shape as ``inp``, assembled from the four
        quadrant outputs.
    """
    _, _, h, w = inp.shape
    h_half, w_half = h // 2, w // 2
    # Each patch covers its half of the image plus `shave` extra pixels.
    h_size, w_size = h_half + shave, w_half + shave
    # Round patch sizes UP to a multiple of 4 (presumably the model's
    # spatial alignment requirement — confirm against the network).
    mod_size = 4
    if h_size % mod_size:
        h_size = ceil(h_size / mod_size) * mod_size
    if w_size % mod_size:
        w_size = ceil(w_size / mod_size) * mod_size
    # Four overlapping quadrants: TL, TR, BL, BR.
    inputlist = [
        inp[:, :, 0:h_size, 0:w_size],
        inp[:, :, 0:h_size, (w - w_size):w],
        inp[:, :, (h - h_size):h, 0:w_size],
        inp[:, :, (h - h_size):h, (w - w_size):w],
    ]
    if w_size * h_size < min_size:
        # Small enough: forward each patch directly.
        outputlist = []
        for patch in inputlist:
            with paddle.no_grad():
                output_refine, _, _ = model(patch)
            outputlist.append(output_refine)
    else:
        # Still too large: chop each quadrant recursively.
        # Fix: propagate shave/min_size instead of silently resetting the
        # recursion to the defaults.
        outputlist = [
            chop_forward(model, patch, shave, min_size)
            for patch in inputlist
        ]
    # NOTE: a leftover no-op `scale = 1` rescaling block from the upstream
    # super-resolution version was removed here; every line multiplied by 1.

    # Stitch the quadrant outputs back together, cutting away the shaved
    # overlap so only each quadrant's own half-region is kept.
    with paddle.no_grad():
        output_ht = paddle.zeros_like(inp)
        output_ht[:, :, 0:h_half, 0:w_half] = \
            outputlist[0][:, :, 0:h_half, 0:w_half]
        output_ht[:, :, 0:h_half, w_half:w] = \
            outputlist[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
        output_ht[:, :, h_half:h, 0:w_half] = \
            outputlist[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
        output_ht[:, :, h_half:h, w_half:w] = \
            outputlist[3][:, :, (h_size - h + h_half):h_size,
                          (w_size - w + w_half):w_size]
    return output_ht