forked from cfchen-duke/ProtoPNet
-
Notifications
You must be signed in to change notification settings - Fork 1
/
preprocess.py
33 lines (27 loc) · 820 Bytes
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
# Per-channel mean subtracted by the normalization helpers in this module,
# indexed by channel (0, 1, 2).
# NOTE(review): these values match the widely used ImageNet RGB statistics --
# confirm against the pretrained backbone actually in use.
mean = (0.485, 0.456, 0.406)
# Per-channel standard deviation paired with `mean` (same channel order).
std = (0.229, 0.224, 0.225)
def preprocess(x, mean, std):
    '''
    Normalize a 3-channel tensor per channel: (x - mean[c]) / std[c].

    x:    tensor whose dimension 1 (the channel axis) has size 3; any number
          of trailing dimensions is accepted, e.g. (N, 3, H, W).
    mean: indexable with at least three per-channel means; only the first
          three entries are used.
    std:  indexable with at least three per-channel standard deviations;
          only the first three entries are used.

    Returns a newly allocated tensor with the same shape and dtype as x.
    x itself is not modified.

    Raises AssertionError if x.size(1) is not 3.
    '''
    assert x.size(1) == 3, 'expected 3 channels in dimension 1'
    # Shape (1, 3, 1, ..., 1): broadcasts over the batch and every trailing
    # dimension, replacing the per-channel Python loop with one tensor op.
    stat_shape = (1, 3) + (1,) * (x.dim() - 2)
    # Keep the statistics in x's own floating dtype so float64 inputs are not
    # degraded to float32 precision; non-float inputs compute in the default
    # float dtype, as a tensor-with-Python-float operation would.
    stat_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    mean_t = torch.as_tensor(mean, dtype=stat_dtype, device=x.device)[:3]
    std_t = torch.as_tensor(std, dtype=stat_dtype, device=x.device)[:3]
    normalized = (x - mean_t.reshape(stat_shape)) / std_t.reshape(stat_shape)
    # The result has always had x's dtype (it used to be written into
    # zeros_like(x)); cast back so non-float inputs keep that contract.
    return normalized.to(x.dtype)
def preprocess_input_function(x):
    '''
    Normalize x with the module-level `mean` and `std`.

    Delegates to `preprocess`, so the result is a newly allocated tensor
    like x carrying the normalization used in the pretrained model.
    '''
    normalized = preprocess(x, mean, std)
    return normalized
def undo_preprocess(x, mean, std):
    '''
    Invert the per-channel normalization: x * std[c] + mean[c].

    x:    tensor whose dimension 1 (the channel axis) has size 3; any number
          of trailing dimensions is accepted, e.g. (N, 3, H, W).
    mean: indexable with at least three per-channel means; only the first
          three entries are used.
    std:  indexable with at least three per-channel standard deviations;
          only the first three entries are used.

    Returns a newly allocated tensor with the same shape and dtype as x.
    x itself is not modified.

    Raises AssertionError if x.size(1) is not 3.
    '''
    assert x.size(1) == 3, 'expected 3 channels in dimension 1'
    # Shape (1, 3, 1, ..., 1): broadcasts over the batch and every trailing
    # dimension, replacing the per-channel Python loop with one tensor op.
    stat_shape = (1, 3) + (1,) * (x.dim() - 2)
    # Keep the statistics in x's own floating dtype so float64 inputs are not
    # degraded to float32 precision; non-float inputs compute in the default
    # float dtype, as a tensor-with-Python-float operation would.
    stat_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    mean_t = torch.as_tensor(mean, dtype=stat_dtype, device=x.device)[:3]
    std_t = torch.as_tensor(std, dtype=stat_dtype, device=x.device)[:3]
    restored = x * std_t.reshape(stat_shape) + mean_t.reshape(stat_shape)
    # The result has always had x's dtype (it used to be written into
    # zeros_like(x)); cast back so non-float inputs keep that contract.
    return restored.to(x.dtype)
def undo_preprocess_input_function(x):
    '''
    Reverse the normalization on x using the module-level `mean` and `std`.

    Delegates to `undo_preprocess`, so the result is a newly allocated
    tensor like x with the pretrained-model normalization undone.
    '''
    restored = undo_preprocess(x, mean, std)
    return restored