forked from jindongwang/transferlearning
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path MyJDA.m
More file actions
143 lines (121 loc) · 4.02 KB
/
MyJDA.m
File metadata and controls
143 lines (121 loc) · 4.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
function [acc,acc_ite,A] = MyJDA(X_src,Y_src,X_tar,Y_tar,options)
%MYJDA Joint Distribution Adaptation (JDA) followed by a 1-NN classifier.
%   Each iteration learns an adaptation with JDA_core, embeds source and
%   target samples, trains a 1-nearest-neighbour classifier on the embedded
%   source samples and predicts the target labels. Those predictions are
%   passed to JDA_core as target pseudo labels on the next iteration.
%
% Inputs:
%%% X_src  : source feature matrix, ns * m
%%% Y_src  : source label vector, ns * 1
%%% X_tar  : target feature matrix, nt * m
%%% Y_tar  : target label vector, nt * 1 (only used to report accuracy)
%%% options: option struct; .T (iteration number) is read here, while
%%%          .lambda, .dim, .kernel_type and .gamma are read by JDA_core
% Outputs:
%%% acc    : accuracy of the last iteration's 1-NN predictions, scalar
%%% acc_ite: T * 1 vector with the accuracy after every iteration
%%% A      : adaptation matrix from the last JDA_core call,
%%%          m * dim for 'primal', (ns + nt) * dim for kernel types
T = options.T; %% iteration number
ns = size(X_src,1);
acc_ite = zeros(T,1);
Y_tar_pseudo = [];
%% Iteration
for i = 1 : T
    [Z,A] = JDA_core(X_src,Y_src,X_tar,Y_tar_pseudo,options);
    % Scale each embedded sample (column of Z) to unit L2 norm for better
    % classification performance. Summing along dimension 1 explicitly keeps
    % this per-column when dim == 1 (Z is a single row); all-zero columns are
    % left untouched instead of becoming NaN.
    col_norm = sqrt(sum(Z.^2,1));
    col_norm(col_norm == 0) = 1;
    Z = Z*diag(sparse(1./col_norm));
    Zs = Z(:,1:ns);
    Zt = Z(:,ns+1:end);
    knn_model = fitcknn(Zs',Y_src,'NumNeighbors',1);
    Y_tar_pseudo = knn_model.predict(Zt');
    % Compare as column vectors so a row-vector Y_tar cannot trigger
    % implicit expansion into an nt * nt comparison matrix.
    acc = mean(Y_tar_pseudo(:) == Y_tar(:));
    fprintf('JDA+NN=%0.4f\n',acc);
    acc_ite(i) = acc;
end
end
function [Z,A] = JDA_core(X_src,Y_src,X_tar,Y_tar_pseudo,options)
%JDA_CORE One JDA step: learn the adaptation matrix and embed the data.
%   Samples are stacked as columns and scaled to unit L2 norm. The marginal
%   MMD matrix M0 is built and, when a pseudo label exists for every target
%   sample, the class-conditional MMD matrices Mc are added. The method then
%   solves the generalized eigenproblem
%       (X*M*X' + lambda*I) a = phi * (X*H*X') a      for 'primal'
%       (K*M*K' + lambda*I) a = phi * (K*H*K') a      for kernel types
%   and keeps the dim eigenvectors of smallest magnitude ('SM').
%
% Inputs:
%%% X_src        : source feature matrix, ns * m
%%% Y_src        : source label vector, ns * 1
%%% X_tar        : target feature matrix, nt * m
%%% Y_tar_pseudo : target pseudo labels (nt * 1), or [] to use M0 only
%%% options      : struct with fields lambda, dim, kernel_type, gamma
% Outputs:
%%% Z : embedded data, dim * (ns + nt), source samples first
%%% A : adaptation matrix, m * dim ('primal') or (ns + nt) * dim (kernel)
%% Set options
lambda = options.lambda;           %% lambda for the regularization
dim = options.dim;                 %% dim is the dimension after adaptation
kernel_type = options.kernel_type; %% 'primal' | 'linear' | 'rbf' | 'sam'
gamma = options.gamma;             %% bandwidth passed to kernel_jda
%% Stack samples as columns and scale each to unit L2 norm
X = [X_src',X_tar'];
% Sum along dimension 1 explicitly so m == 1 is still normalized per
% sample; all-zero samples stay zero instead of becoming NaN columns.
col_norm = sqrt(sum(X.^2,1));
col_norm(col_norm == 0) = 1;
X = X*diag(sparse(1./col_norm));
[m,n] = size(X);
ns = size(X_src,1);
nt = size(X_tar,1);
%% Construct MMD matrix
classes = reshape(unique(Y_src),1,[]);
C = numel(classes);
%%% M0: difference between the source mean and the target mean
e = [1/ns*ones(ns,1);-1/nt*ones(nt,1)];
M = e * e' * C; % multiply C for better normalization
%%% Mc: per-class mean differences, once every target sample has a label
N = 0;
if ~isempty(Y_tar_pseudo) && length(Y_tar_pseudo)==nt
    for c = classes
        e = zeros(n,1);
        src_idx = find(Y_src==c);
        tar_idx = find(Y_tar_pseudo==c);
        % Assigning to an empty index set is a no-op, so a class missing
        % from the pseudo labels contributes only its source term.
        e(src_idx) = 1 / numel(src_idx);
        e(ns+tar_idx) = -1 / numel(tar_idx);
        N = N + e*e';
    end
end
M = M + N;
M = M / norm(M,'fro');
%% Centering matrix H
H = eye(n) - 1/n * ones(n,n);
%% Calculation
if strcmp(kernel_type,'primal')
    [A,~] = eigs(X*M*X'+lambda*eye(m),X*H*X',dim,'SM');
    Z = A'*X;
else
    K = kernel_jda(kernel_type,X,[],gamma);
    [A,~] = eigs(K*M*K'+lambda*eye(n),K*H*K',dim,'SM');
    Z = A'*K;
end
end
% KERNEL_JDA Kernel matrix between the columns of X (and X2).
%   K = kernel_jda(ker, X, X2, gamma) returns K(i,j) = k(X(:,i), X2(:,j)),
%   or k(X(:,i), X(:,j)) when X2 is empty.
%   For speed, the RBF kernel expands the squared Euclidean distance as
%   ||x||^2 + ||y||^2 - 2*x'*y instead of forming pairwise differences.
%
% Inputs:
%   ker:   'linear', 'rbf' or 'sam'
%   X:     data matrix (features * samples)
%   X2:    second data matrix (features * samples2), or [] to reuse X
%   gamma: bandwidth of the RBF/SAM kernel
% Output:
%   K: kernel matrix, size(X,2) * size(X2,2) (size(X,2) * size(X,2) if X2 is empty)
%
% Gustavo Camps-Valls
% 2006(c)
% Jordi (jordi@uv.es), 2007
% 2007-11: if/then -> switch, and fixed RBF kernel
% Modified by Mingsheng Long
% 2013(c)
% Mingsheng Long (longmingsheng@gmail.com), 2013
function K = kernel_jda(ker,X,X2,gamma)
switch ker
    case 'linear'
        if isempty(X2)
            K = X'*X;
        else
            K = X'*X2;
        end
    case 'rbf'
        n1sq = sum(X.^2,1);
        n1 = size(X,2);
        if isempty(X2)
            D = (ones(n1,1)*n1sq)' + ones(n1,1)*n1sq -2*X'*X;
        else
            n2sq = sum(X2.^2,1);
            n2 = size(X2,2);
            D = (ones(n2,1)*n1sq)' + ones(n1,1)*n2sq -2*X'*X2;
        end
        % Cancellation in the expansion can leave tiny negative squared
        % distances; clamp them so no kernel value exceeds 1.
        D = max(D,0);
        K = exp(-gamma*D);
    case 'sam'
        % acos(x'*y) is the angle between x and y only for unit-norm
        % columns (JDA_core normalizes its input this way).
        if isempty(X2)
            D = X'*X;
        else
            D = X'*X2;
        end
        % Rounding can push cosines just outside [-1,1], where acos returns
        % complex values; clamp to acos's real domain.
        D = min(max(D,-1),1);
        K = exp(-gamma*acos(D).^2);
    otherwise
        error(['Unsupported kernel ' ker])
end
end