forked from lelegan/LP_KSVD
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinitialization4LPKSVD.m
81 lines (71 loc) · 2.84 KB
/
initialization4LPKSVD.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
% ========================================================================
% Initialization for the Locality Preserving K-SVD (LP-KSVD) algorithm.
% USAGE: [Dinit,Tinit,Winit,Q] = initialization4LPKSVD(training_feats,....
%                  H_train,dictsize,iterations,sparsitythres,Q_in,num)
% Inputs
%       training_feats  -training features (one sample per column)
%       H_train         -label matrix for training features (classes x samples)
%       dictsize        -number of dictionary items
%       iterations      -number of K-SVD iterations
%       sparsitythres   -sparsity threshold (max nonzeros per sparse code)
%       Q_in            -initialized locality preserving matrix
%       num             -iteration number (num==1 builds Q as all zeros,
%                        otherwise Q_in is reused unchanged)
% Outputs
%       Dinit           -initialized dictionary
%       Tinit           -initialized linear transform matrix
%       Winit           -initialized classifier parameters
%       Q               -computed locality preserving matrix
%
% Author: Weiyang Liu (wyliu@pku.edu.cn)
% Date: 11-20-2014
% ========================================================================
function [Dinit,Tinit,Winit,Q]=initialization4LPKSVD(training_feats,H_train,dictsize,iterations,sparsitythres,Q_in,num)
numClass = size(H_train,1);             % number of classes
numPerClass = round(dictsize/numClass); % dictionary atoms seeded per class
Dinit = [];
for classid=1:numClass
    col_ids = find(H_train(classid,:)==1);
    % ensure no zero-energy data elements are chosen
    data_ids = find(colnorms_squared_new(training_feats(:,col_ids)) > 1e-6);
    % deterministic pick of the first numPerClass samples;
    % swap in randperm(length(data_ids)) for a random selection
    perm = 1:length(data_ids);
    %%% Initialization for LP-KSVD (perform K-SVD in each class)
    Dpart = training_feats(:,col_ids(data_ids(perm(1:numPerClass))));
    para.data = training_feats(:,col_ids(data_ids));
    para.Tdata = sparsitythres;
    para.iternum = iterations;
    para.memusage = 'high';
    % normalization of the initial per-class atoms
    para.initdict = normcols(Dpart);
    % ksvd process (only the learned sub-dictionary is needed)
    [Dpart,~,~] = ksvd(para,'');
    Dinit = [Dinit Dpart]; %#ok<AGROW> % loop count is small (numClass)
end
% Q: locality preserving (label-constraint) code matrix
if num==1
    Q = zeros(dictsize,size(training_feats,2)); % energy matrix
else
    Q = Q_in; % reuse the caller-supplied matrix on later iterations
end
% Global K-SVD over the full training set, seeded with the per-class atoms
params.data = training_feats;
params.Tdata = sparsitythres; % sparsity term
params.iternum = iterations;
params.memusage = 'high';
% normalization
params.initdict = normcols(Dinit);
% ksvd process (only the sparse codes Xtemp are needed below)
[~,Xtemp,~] = ksvd(params,'');
% Ridge-regression initializations of the classifier and transform:
%   Winit' = (X*X' + I) \ (X*H'),   Tinit' = (X*X' + I) \ (X*Q')
% The backslash (mldivide) solve replaces inv(...)*b: same result,
% faster and numerically more stable; the Gram matrix is hoisted so it
% is factorized from a single shared expression.
G = Xtemp*Xtemp' + eye(size(Xtemp,1));
Winit = (G \ (Xtemp*H_train'))';
Tinit = (G \ (Xtemp*Q'))';