function res = DMSPDeblur(degraded, kernel, sigma_d, params)
% Implements stochastic gradient descent (SGD) Bayes risk minimization for image deblurring described in:
% "Deep Mean-Shift Priors for Image Restoration" (http://home.inf.unibe.ch/~bigdeli/DMSPrior.html)
% S. A. Bigdeli, M. Jin, P. Favaro, M. Zwicker, Advances in Neural Information Processing Systems (NIPS), 2017
%
% Input:
%   degraded: Observed degraded RGB input image in the range [0, 255].
%   kernel: Blur kernel (internally flipped for convolution).
%   sigma_d: Noise standard deviation (set to -1 for noise-blind deblurring).
%   params: Set of parameters.
%   params.denoiser: The denoiser function handle.
%
% Optional parameters:
%   params.sigma_dae: The standard deviation of the denoiser training noise. default: 11
%   params.num_iter: Specifies the number of iterations. default: 300
%   params.mu: The momentum for SGD optimization. default: 0.9
%   params.alpha: The step length in SGD optimization. default: 0.1
%
%
% Outputs:
% res: Solution.
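%
% Example usage (a minimal sketch; 'DAEdenoise' and all variable names
% below are hypothetical placeholders, not part of this file):
%   params = struct();
%   params.denoiser = @(noisy) DAEdenoise(noisy);
%   params.gt = gt_image;  % optional: prints per-iteration PSNR
%   restored = DMSPDeblur(blurry, psf, 255 * 0.01, params);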
if ~isfield(params, 'denoiser')
    error('Need a denoiser in params.denoiser!');
end

if ~isfield(params, 'sigma_dae')
    params.sigma_dae = 11;
end
if ~isfield(params, 'num_iter')
    params.num_iter = 300;
end
if ~isfield(params, 'mu')
    params.mu = .9;
end
if ~isfield(params, 'alpha')
    params.alpha = .1;
end
print_iter = isfield(params, 'gt'); % report PSNR when ground truth is given
pad = floor(size(kernel)/2);
res = padarray(degraded, pad, 'replicate', 'both');
step = zeros(size(res));
if print_iter
    psnr = computePSNR(params.gt, res, pad);
    disp(['Initialized with PSNR: ' num2str(psnr)]);
end
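
% The noise injected into the denoiser below is redrawn every iteration,
% so each prior gradient is a stochastic sample; combined with the
% momentum update, the loop performs the SGD Bayes-risk minimization
% described in the header.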
for iter = 1:params.num_iter
    if print_iter
        disp(['Running iteration: ' num2str(iter)]);
        tic();
    end
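
    % Perturbing the estimate with Gaussian noise of std sigma_dae and
    % passing it through the denoising autoencoder gives rec; the residual
    % (input - rec) approximates the gradient of the smoothed negative
    % log-prior (the mean-shift vector).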
    % compute prior gradient
    input = res(:,:,[3,2,1]); % Switch channels for network
    noise = randn(size(input)) * params.sigma_dae;
    rec = params.denoiser(input + noise);
    prior_grad = input - rec;
    prior_grad = prior_grad(:,:,[3,2,1]);
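
    % The data term is ||A*x - y||^2 / (2 sigma_d^2), where A is the blur
    % operator: map_conv applies A as a 'valid' convolution with the flipped
    % kernel, and data_grad applies its adjoint A' as a 'full' convolution
    % with the kernel.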
    % compute data gradient
    map_conv = convn(res, rot90(kernel,2), 'valid');
    data_err = map_conv - degraded;
    data_grad = convn(data_err, kernel, 'full');
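
    % Weight the two gradients. In the noise-blind case (sigma_d < 0),
    % lambda plays the role of an estimated 1/sigma_d^2 derived from the
    % energy of the current data residual (see the noise-blind formulation
    % in the paper); otherwise the weight follows directly from the known
    % sigma_d and sigma_dae.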
    if sigma_d < 0
        sigma2 = 2 * params.sigma_dae * params.sigma_dae;
        lambda = numel(degraded) / (sum(data_err(:).^2) + numel(degraded) * sigma2 * sum(kernel(:).^2));
        relative_weight = lambda / (lambda + 1/params.sigma_dae^2);
    else
        relative_weight = (1/sigma_d^2) / (1/sigma_d^2 + 1/params.sigma_dae^2);
    end
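
    % Momentum SGD update: step is an exponential moving average of past
    % gradients (factor params.mu), scaled by the step length params.alpha;
    % the estimate is then clamped to the valid intensity range [0, 255].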
    % sum the gradients
    grad_joint = data_grad * relative_weight + prior_grad * (1 - relative_weight);

    % update
    step = params.mu * step - params.alpha * grad_joint;
    res = res + step;
    res = min(255, max(0, res));

    if print_iter
        psnr = computePSNR(params.gt, res, pad);
        disp(['PSNR is: ' num2str(psnr) ', iteration finished in ' num2str(toc()) ' seconds']);
    end
end